From e6d3927567b415971d28f15b7bd89f364b576ccd Mon Sep 17 00:00:00 2001
From: Philip Loche
Date: Sat, 7 Sep 2024 12:12:27 +0200
Subject: [PATCH] Add a general torch CompositionModel (#280)

---------

Co-authored-by: frostedoyster
Co-authored-by: Guillaume Fraux
---
 docs/src/conf.py                                   |   5 +-
 .../experimental/alchemical_model/trainer.py       |   2 +-
 .../alchemical_model/utils/composition.py          |  69 ++++
 src/metatrain/experimental/gap/model.py            |  22 +-
 src/metatrain/experimental/gap/trainer.py          |  31 +-
 src/metatrain/experimental/soap_bpnn/model.py      |  55 ++-
 .../soap_bpnn/tests/test_continue.py               |   6 +
 .../soap_bpnn/tests/test_regression.py             |  26 +-
 .../experimental/soap_bpnn/trainer.py              |  71 +---
 src/metatrain/utils/composition.py                 | 355 ++++++++++++-----
 tests/resources/generate-outputs.sh                |   6 +-
 tests/utils/test_composition.py                    | 356 +++++++++++++++++-
 12 files changed, 774 insertions(+), 230 deletions(-)
 create mode 100644 src/metatrain/experimental/alchemical_model/utils/composition.py

diff --git a/docs/src/conf.py b/docs/src/conf.py
index a0819459e..c9de1502a 100644
--- a/docs/src/conf.py
+++ b/docs/src/conf.py
@@ -5,14 +5,15 @@
 import tomli  # Replace by tomllib from std library once docs are build with Python 3.11
 
-import metatrain
-
 # When importing metatensor-torch, this will change the definition of the classes
 # to include the documentation
 os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
 os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1"
 
+import metatrain  # noqa: E402
+
+
 ROOT = os.path.abspath(os.path.join("..", ".."))
 
 # We use a second (pseudo) sphinx project located in `docs/generate_examples` to run the
diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py
index 2ff0693dd..dddbc9839 100644
--- a/src/metatrain/experimental/alchemical_model/trainer.py
+++ b/src/metatrain/experimental/alchemical_model/trainer.py
@@ -5,7 +5,6 @@
 import torch
 from metatensor.learn.data import DataLoader
 
-from ...utils.composition import calculate_composition_weights
 from ...utils.data import (
     CombinedDataLoader,
     Dataset,
@@ -23,6 +22,7 @@ from ...utils.neighbor_lists import get_system_with_neighbor_lists
 from ...utils.per_atom import average_by_num_atoms
 from . import AlchemicalModel
+from .utils.composition import calculate_composition_weights
 from .utils.normalize import (
     get_average_number_of_atoms,
     get_average_number_of_neighbors,
diff --git a/src/metatrain/experimental/alchemical_model/utils/composition.py b/src/metatrain/experimental/alchemical_model/utils/composition.py
new file mode 100644
index 000000000..879672135
--- /dev/null
+++ b/src/metatrain/experimental/alchemical_model/utils/composition.py
@@ -0,0 +1,69 @@
+from typing import List, Tuple, Union
+
+import torch
+
+from ....utils.data.dataset import Dataset, get_atomic_types
+
+
+def calculate_composition_weights(
+    datasets: Union[Dataset, List[Dataset]], property: str
+) -> Tuple[torch.Tensor, List[int]]:
+    """Calculate the composition weights for a dataset.
+
+    It assumes per-system properties.
+
+    :param datasets: Dataset(s) to calculate the composition weights for.
+    :returns: Composition weights for the dataset, as well as the
+        list of species that the weights correspond to.
+    """
+    if not isinstance(datasets, list):
+        datasets = [datasets]
+
+    # Note: `atomic_types` are sorted, and the composition weights are sorted as
+    # well, because the species are sorted in the composition features.
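+    # The weights solve the ridge-regularized least-squares problem
+    #     (X^T X + regularizer * I) w = X^T y,
+    # where X[i, j] counts the atoms of type j in structure i and y holds the
+    # per-structure target values; `regularizer` is grown below until the
+    # solve succeeds.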
+ atomic_types = sorted(get_atomic_types(datasets)) + + targets = torch.stack( + [sample[property].block().values for dataset in datasets for sample in dataset] + ) + targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions + + total_num_structures = sum([len(dataset) for dataset in datasets]) + dtype = datasets[0][0]["system"].positions.dtype + composition_features = torch.empty( + (total_num_structures, len(atomic_types)), dtype=dtype + ) + structure_index = 0 + for dataset in datasets: + for sample in dataset: + structure = sample["system"] + for j, s in enumerate(atomic_types): + composition_features[structure_index, j] = torch.sum( + structure.types == s + ) + structure_index += 1 + + regularizer = 1e-20 + while regularizer: + if regularizer > 1e5: + raise RuntimeError( + "Failed to solve the linear system to calculate the " + "composition weights. The dataset is probably too small " + "or ill-conditioned." + ) + try: + solution = torch.linalg.solve( + composition_features.T @ composition_features + + regularizer + * torch.eye( + composition_features.shape[1], + dtype=composition_features.dtype, + device=composition_features.device, + ), + composition_features.T @ targets, + ) + break + except torch._C._LinAlgError: + regularizer *= 10.0 + + return solution, atomic_types diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py index 97a24186c..1ce0a9cb5 100644 --- a/src/metatrain/experimental/gap/model.py +++ b/src/metatrain/experimental/gap/model.py @@ -21,6 +21,7 @@ from metatrain.utils.data.dataset import DatasetInfo +from ...utils.composition import CompositionModel from ...utils.export import export @@ -127,6 +128,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: ) self._species_labels: TorchLabels = TorchLabels.empty("_") + self.composition_model = CompositionModel( + model_hypers={}, + dataset_info=dataset_info, + ) + def restart(self, dataset_info: DatasetInfo) -> "GAP": raise ValueError("GAP does not allow restarting training") @@ -201,8 +207,20 @@ def forward( soap_features = TorchTensorMap(self._keys, soap_features.blocks()) output_key = list(outputs.keys())[0] energies = self._subset_of_regressors_torch(soap_features) - out_tensor = self.apply_composition_weights(systems, energies) - return {output_key: out_tensor} + return_dict = {output_key: energies} + + # apply composition model + composition_energies = self.composition_model( + systems, {output_key: ModelOutput("energy", per_atom=True)}, selected_atoms + ) + composition_energies[output_key] = metatensor.torch.sum_over_samples( + composition_energies[output_key], "atom" + ) + return_dict[output_key] = metatensor.torch.add( + return_dict[output_key], composition_energies[output_key] + ) + + return return_dict def export(self) -> MetatensorAtomisticModel: capabilities = ModelCapabilities( diff --git a/src/metatrain/experimental/gap/trainer.py b/src/metatrain/experimental/gap/trainer.py index a6e92809c..ffeb40951 100644 --- a/src/metatrain/experimental/gap/trainer.py +++ b/src/metatrain/experimental/gap/trainer.py @@ -9,7 +9,7 @@ from metatrain.utils.data import Dataset -from ...utils.composition import calculate_composition_weights +from ...utils.composition import remove_composition from ...utils.data import check_datasets from . 
import GAP from .model import torch_tensor_map_to_core @@ -52,10 +52,7 @@ def train( # Calculate and set the composition weights: logger.info("Calculating composition weights") - composition_weights, species = calculate_composition_weights( - train_datasets, target_name - ) - model.set_composition_weights(target_name, composition_weights, species) + model.composition_model.train_model(train_datasets) logger.info("Setting up data loaders") if len(train_datasets[0][0][output_name].keys) > 1: @@ -72,26 +69,10 @@ def train( model._keys = train_y.keys train_structures = [sample["system"] for sample in train_dataset] - logger.info("Fitting composition energies") - composition_energies = torch.zeros(len(train_y.block().values), dtype=dtype) - for i, structure in enumerate(train_structures): - for j, s in enumerate(species): - composition_energies[i] += ( - torch.sum(structure.types == s) * composition_weights[j] - ) - train_y_values = train_y.block().values - train_y_values = train_y_values - composition_energies.reshape(-1, 1) - train_block = metatensor.torch.TensorBlock( - values=train_y_values, - samples=train_y.block().samples, - components=train_y.block().components, - properties=train_y.block().properties, - ) - if len(train_y[0].gradients_list()) > 0: - train_block.add_gradient("positions", train_y[0].gradient("positions")) - train_y = metatensor.torch.TensorMap( - train_y.keys, - [train_block], + logger.info("Subtracting composition energies") + # this acts in-place on train_y + remove_composition( + train_structures, {target_name: train_y}, model.composition_model ) logger.info("Calculating SOAP features") diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 7f144d27f..4b99f9229 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -16,7 +16,7 @@ from metatrain.utils.data.dataset import DatasetInfo -from ...utils.composition import apply_composition_contribution +from ...utils.composition import CompositionModel from ...utils.dtype import dtype_to_str from ...utils.export import export @@ -123,14 +123,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: unit="unitless", per_atom=True ) - # creates a composition weight tensor that can be directly indexed by species, - # this can be left as a tensor of zero or set from the outside using - # set_composition_weights (recommended for better accuracy) - n_outputs = len(self.outputs) - self.register_buffer( - "composition_weights", - torch.zeros((n_outputs, max(self.atomic_types) + 1)), - ) # buffers cannot be indexed by strings (torchscript), so we create a single # tensor for all output. 
Due to this, we need to slice the tensor when we use # it and use the output name to select the correct slice via a dictionary @@ -195,6 +187,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: } ) + self.composition_model = CompositionModel( + model_hypers={}, + dataset_info=dataset_info, + ) + def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": # merge old and new dataset info merged_info = self.dataset_info.union(dataset_info) @@ -261,12 +258,7 @@ def forward( atomic_energies: Dict[str, TensorMap] = {} for output_name, output_layer in self.last_layers.items(): if output_name in outputs: - atomic_energies[output_name] = apply_composition_contribution( - output_layer(last_layer_features), - self.composition_weights[ # type: ignore - self.output_to_index[output_name] - ], - ) + atomic_energies[output_name] = output_layer(last_layer_features) # Sum the atomic energies coming from the BPNN to get the total energy for output_name, atomic_energy in atomic_energies.items(): @@ -281,6 +273,19 @@ def forward( atomic_energy, ["atom", "center_type"] ) + if not self.training: + # at evaluation, we also add the composition contributions + composition_contributions = self.composition_model( + systems, outputs, selected_atoms + ) + for name in return_dict: + if name.startswith("mtt::aux::"): + continue # skip auxiliary outputs (not targets) + return_dict[name] = metatensor.torch.add( + return_dict[name], + composition_contributions[name], + ) + return return_dict @classmethod @@ -303,6 +308,11 @@ def export(self) -> MetatensorAtomisticModel: if dtype not in self.__supported_dtypes__: raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn") + # Make sure the model is all in the same dtype + # For example, at this point, the composition model within the SOAP-BPNN is + # still float64 + self.to(dtype) + capabilities = ModelCapabilities( outputs=self.outputs, atomic_types=self.atomic_types, @@ -314,21 +324,6 @@ def export(self) -> MetatensorAtomisticModel: return export(model=self, model_capabilities=capabilities) - def set_composition_weights( - self, - output_name: str, - input_composition_weights: torch.Tensor, - atomic_types: List[int], - ) -> None: - """Set the composition weights for a given output.""" - # all species that are not present retain their weight of zero - self.composition_weights[self.output_to_index[output_name]][ # type: ignore - atomic_types - ] = input_composition_weights.to( - dtype=self.composition_weights.dtype, # type: ignore - device=self.composition_weights.device, # type: ignore - ) - def add_output(self, output_name: str) -> None: """Add a new output to the self.""" # add a new row to the composition weights tensor diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index dac96cd94..fc732897e 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -44,6 +44,9 @@ def test_continue(monkeypatch, tmp_path): } } targets, _ = read_targets(OmegaConf.create(conf)) + + # systems in float64 are required for training + systems = [system.to(torch.float64) for system in systems] dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() @@ -63,6 +66,9 @@ def test_continue(monkeypatch, tmp_path): checkpoint_dir=".", ) + # evaluation + systems = [system.to(torch.float32) for system in systems] + # Predict on the first five 
systems
     output_before = model_before(
         systems[:5], {"mtt::U0": model_before.outputs["mtt::U0"]}
     )
diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py
index 0663da485..7b4161ddb 100644
--- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py
+++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py
@@ -39,12 +39,18 @@ def test_regression_init():
     )
 
     expected_output = torch.tensor(
-        [[-0.03860], [0.11137], [0.09112], [-0.05634], [-0.02549]]
+        [
+            [-0.038599025458],
+            [0.111374437809],
+            [0.091115802526],
+            [-0.056339077652],
+            [-0.025491207838],
+        ]
     )
 
     # if you need to change the hardcoded values:
     # torch.set_printoptions(precision=12)
     # print(output["mtt::U0"].block().values)
 
     torch.testing.assert_close(
         output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
@@ -100,17 +106,17 @@ def test_regression_train():
 
     expected_output = torch.tensor(
         [
-            [-40.592571258545],
-            [-56.522350311279],
-            [-76.571365356445],
-            [-77.384849548340],
-            [-93.445365905762],
+            [-0.106249026954],
+            [0.039981484413],
+            [-0.142682999372],
+            [-0.031701669097],
+            [-0.016210660338],
         ]
     )
 
     # if you need to change the hardcoded values:
     # torch.set_printoptions(precision=12)
     # print(output["mtt::U0"].block().values)
 
     torch.testing.assert_close(
         output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py
index 4a342b158..47cc09174 100644
--- a/src/metatrain/experimental/soap_bpnn/trainer.py
+++ b/src/metatrain/experimental/soap_bpnn/trainer.py
@@ -7,14 +7,8 @@
 import torch.distributed
 from torch.utils.data import DataLoader, DistributedSampler
 
-from ...utils.composition import calculate_composition_weights
-from ...utils.data import (
-    CombinedDataLoader,
-    Dataset,
-    TargetInfoDict,
-    collate_fn,
-    get_all_targets,
-)
+from ...utils.composition import remove_composition
+from ...utils.data import CombinedDataLoader, Dataset, TargetInfoDict, collate_fn
 from ...utils.data.extract_targets import get_targets_dict
 from ...utils.distributed.distributed_data_parallel import DistributedDataParallel
 from ...utils.distributed.slurm import DistributedEnvironment
@@ -90,58 +84,21 @@ def train(
         logger.info(f"Training on {world_size} devices with dtype {dtype}")
     else:
         logger.info(f"Training on device {device} with dtype {dtype}")
+
+    # Move the model to the device and dtype:
     model.to(device=device, dtype=dtype)
-    if is_distributed:
-        model = DistributedDataParallel(model, device_ids=[device])
+    # The composition model of the SOAP-BPNN is always on CPU (to avoid OOM
+    # errors during the linear algebra training) and in float64 (to avoid
+    # numerical errors in the composition weights, which can be very large).
+ model.composition_model.to(device=torch.device("cpu"), dtype=torch.float64) - # Calculate and set the composition weights for all targets: logger.info("Calculating composition weights") - for target_name in (model.module if is_distributed else model).new_outputs: - if "mtt::aux::" in target_name: - continue - # TODO: document transfer learning and say that outputs that are already - # present in the model will keep their composition weights - if target_name in self.hypers["fixed_composition_weights"].keys(): - logger.info( - f"For {target_name}, model will use " - "user-supplied composition weights" - ) - cur_weight_dict = self.hypers["fixed_composition_weights"][target_name] - atomic_types = [] - num_species = len(cur_weight_dict) - fixed_weights = torch.zeros(num_species, dtype=dtype, device=device) - - for ii, (key, weight) in enumerate(cur_weight_dict.items()): - atomic_types.append(key) - fixed_weights[ii] = weight - - if ( - not set(atomic_types) - == (model.module if is_distributed else model).atomic_types - ): - raise ValueError( - "Supplied atomic types are not present in the dataset." - ) - (model.module if is_distributed else model).set_composition_weights( - target_name, fixed_weights, atomic_types - ) + model.composition_model.train_model( + train_datasets, self.hypers["fixed_composition_weights"] + ) - else: - train_datasets_with_target = [] - for dataset in train_datasets: - if target_name in get_all_targets(dataset): - train_datasets_with_target.append(dataset) - if len(train_datasets_with_target) == 0: - raise ValueError( - f"Target {target_name} in the model's new capabilities is not " - "present in any of the training datasets." - ) - composition_weights, composition_types = calculate_composition_weights( - train_datasets_with_target, target_name - ) - (model.module if is_distributed else model).set_composition_weights( - target_name, composition_weights, composition_types - ) + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) logger.info("Setting up data loaders") @@ -273,6 +230,7 @@ def train( optimizer.zero_grad() systems, targets = batch + remove_composition(systems, targets, model.composition_model) systems = [system.to(dtype=dtype, device=device) for system in systems] targets = { key: value.to(dtype=dtype, device=device) @@ -312,6 +270,7 @@ def train( val_loss = 0.0 for batch in val_dataloader: systems, targets = batch + remove_composition(systems, targets, model.composition_model) systems = [system.to(dtype=dtype, device=device) for system in systems] targets = { key: value.to(dtype=dtype, device=device) diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py index be2cca5a3..1f96e0549 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/composition.py @@ -1,105 +1,290 @@ -from typing import List, Tuple, Union +import warnings +from typing import Dict, List, Optional, Union +import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ModelOutput, System -from metatrain.utils.data import Dataset, get_atomic_types +from .data import Dataset, DatasetInfo, get_all_targets, get_atomic_types +from .jsonschema import validate -def calculate_composition_weights( - datasets: Union[Dataset, List[Dataset]], property: str -) -> Tuple[torch.Tensor, List[int]]: - """Calculate the composition weights for a dataset. 
+class CompositionModel(torch.nn.Module):
+    """A simple model that calculates the energy based on the stoichiometry of a system.
+
+    :param model_hypers: A dictionary of model hyperparameters. This parameter is
+        ignored and is only present to be consistent with the general model API.
+    :param dataset_info: An object containing information about the dataset, including
+        target quantities and atomic types.
+
+    :raises ValueError: If any target quantity in the dataset info is not an
+        energy-like quantity.
+    """
+
+    def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
+        super().__init__()
+
+        # `model_hypers` should be an empty dictionary
+        validate(
+            instance=model_hypers,
+            schema={"type": "object", "additionalProperties": False},
+        )
+
+        # Check capabilities
+        for target in dataset_info.targets.values():
+            if target.quantity != "energy":
+                raise ValueError(
+                    "CompositionModel only supports energy-like outputs, but a "
+                    f"{target.quantity} output was provided."
+                )
+
+        self.dataset_info = dataset_info
+        self.atomic_types = sorted(dataset_info.atomic_types)
+
+        n_types = len(self.atomic_types)
+        n_targets = len(dataset_info.targets)
+
+        self.output_to_output_index = {
+            target: i for i, target in enumerate(sorted(dataset_info.targets.keys()))
+        }
+
+        self.register_buffer(
+            "weights", torch.zeros((n_targets, n_types), dtype=torch.float64)
+        )
+
+    def train_model(
+        self,
+        datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        fixed_weights: Optional[Dict[str, Dict[int, float]]] = None,
+    ) -> None:
+        """Train/fit the composition weights for the datasets.
+
+        :param datasets: Dataset(s) to calculate the composition weights for.
+        :param fixed_weights: Optional fixed weights to use for the composition model,
+            for one or more target quantities.
+
+        :raises ValueError: If the provided datasets contain unknown atomic types.
+        :raises RuntimeError: If the linear system to calculate the composition weights
+            cannot be solved.
+        """
+        if not isinstance(datasets, list):
+            datasets = [datasets]
+
+        if fixed_weights is None:
+            fixed_weights = {}
+
+        additional_types = sorted(
+            set(get_atomic_types(datasets)) - set(self.atomic_types)
+        )
+        if additional_types:
+            raise ValueError(
+                "Provided `datasets` contains unknown "
+                f"atomic types {additional_types}. "
+                f"Known types from initialization are {self.atomic_types}."
+            )
+
+        missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets)))
+        if missing_types:
+            warnings.warn(
+                f"Provided `datasets` do not contain atomic types {missing_types}. "
+                f"Known types from initialization are {self.atomic_types}.",
+                stacklevel=2,
+            )
+
+        # Fill the weights for each target in the dataset info
+        for target_key in self.output_to_output_index.keys():
+            if target_key in fixed_weights:
+                # The fixed weights are provided for this target. Use them:
+                if not sorted(fixed_weights[target_key].keys()) == self.atomic_types:
+                    raise ValueError(
+                        f"Fixed weights for target {target_key} must contain all "
+                        f"atomic types {self.atomic_types}."
+                    )
+
+                self.weights[self.output_to_output_index[target_key]] = torch.tensor(
+                    [fixed_weights[target_key][i] for i in self.atomic_types],
+                    dtype=self.weights.dtype,
+                )
+            else:
+                datasets_with_target = []
+                for dataset in datasets:
+                    if target_key in get_all_targets(dataset):
+                        datasets_with_target.append(dataset)
+                if len(datasets_with_target) == 0:
+                    raise ValueError(
+                        f"Target {target_key} in the model's new capabilities is not "
+                        "present in any of the training datasets."
+                    )
+
+                targets = torch.stack(
+                    [
+                        sample[target_key].block().values
+                        for dataset in datasets_with_target
+                        for sample in dataset
+                    ]
+                )
+
+                # remove component and property dimensions
+                targets = targets.squeeze(dim=(1, 2))
+
+                total_num_structures = sum(
+                    [len(dataset) for dataset in datasets_with_target]
+                )
+                dtype = datasets[0][0]["system"].positions.dtype
+                if dtype != torch.float64:
+                    raise ValueError(
+                        "The composition model only supports float64 during training. "
+                        f"Got dtype: {dtype}."
+                    )
+
+                composition_features = torch.zeros(
+                    (total_num_structures, len(self.atomic_types)), dtype=dtype
+                )
+                structure_index = 0
+                for dataset in datasets_with_target:
+                    for sample in dataset:
+                        structure = sample["system"]
+                        for j, t in enumerate(self.atomic_types):
+                            composition_features[structure_index, j] = torch.sum(
+                                structure.types == t
+                            )
+                        structure_index += 1
+
+                regularizer = 1e-20
+                while regularizer:
+                    if regularizer > 1e5:
+                        raise RuntimeError(
+                            "Failed to solve the linear system to calculate the "
+                            "composition weights. The dataset is probably too small or "
+                            "ill-conditioned."
+                        )
+                    try:
+                        self.weights[self.output_to_output_index[target_key]] = (
+                            torch.linalg.solve(
+                                composition_features.T @ composition_features
+                                + regularizer
+                                * torch.eye(
+                                    composition_features.shape[1],
+                                    dtype=composition_features.dtype,
+                                    device=composition_features.device,
+                                ),
+                                composition_features.T @ targets,
+                            ).to(self.weights.dtype)
+                        )
+                        break
+                    except torch._C._LinAlgError:
+                        regularizer *= 10.0
+
+    def restart(self, dataset_info: DatasetInfo) -> "CompositionModel":
+        """Restart the model with a new dataset info.
+
+        :param dataset_info: New dataset information to be used.
+        """
+        return CompositionModel(
+            model_hypers={},
+            dataset_info=self.dataset_info.union(dataset_info),
+        )
+
+    def forward(
+        self,
+        systems: List[System],
+        outputs: Dict[str, ModelOutput],
+        selected_atoms: Optional[Labels] = None,
+    ) -> Dict[str, TensorMap]:
+        """Compute the targets for each system based on the composition weights.
+
+        :param systems: List of systems to calculate the energy per atom.
+        :param outputs: Dictionary containing the model outputs.
+        :param selected_atoms: Optional selection of atoms for which to compute the
+            targets.
+        :returns: A dictionary with the computed targets for each system.
+
+        :raises ValueError: If `outputs` contains keys that are not supported by
+            this composition model.
+        """
+        dtype = systems[0].positions.dtype
+        device = systems[0].positions.device
+
+        for output_name in outputs:
+            if output_name.startswith("mtt::aux::"):
+                continue
+            if output_name not in self.output_to_output_index:
+                raise ValueError(
+                    f"output key {output_name} is not supported by this composition "
+                    "model."
+                )
+
+        # Compute the targets for each system by adding the composition weights times
+        # number of atoms per atomic type.
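+        # (Each atom is assigned the weight of its own type; for outputs with
+        # per_atom=False the per-atom values are then summed over the "atom"
+        # samples, which equals multiplying each type's weight by its count.)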
+ targets_out: Dict[str, TensorMap] = {} + for target_key, target in outputs.items(): + if target_key.startswith("mtt::aux::"): + continue + weights = self.weights[self.output_to_output_index[target_key]] + targets_list = [] + sample_values: List[List[int]] = [] + + for i_system, system in enumerate(systems): + targets_single = torch.zeros(len(system), dtype=dtype, device=device) + + for i_type, atomic_type in enumerate(self.atomic_types): + targets_single[atomic_type == system.types] = weights[i_type] + + targets_list.append(targets_single) + sample_values += [[i_system, i_atom] for i_atom in range(len(system))] + + targets = torch.concatenate(targets_list) + + block = TensorBlock( + values=targets.reshape(-1, 1), + samples=Labels( + ["system", "atom"], torch.tensor(sample_values, device=device) + ), + components=[], + properties=Labels( + names=["energy"], values=torch.tensor([[0]], device=device) + ), ) - ) - new_keys_labels = Labels( - names=["center_type"], - values=torch.tensor(new_keys, device=new_blocks[0].values.device).reshape( - -1, 1 - ), - ) + targets_out[target_key] = TensorMap( + keys=Labels(names=["_"], values=torch.tensor([[0]], device=device)), + blocks=[block], + ) + + # apply selected_atoms to the composition if needed + if selected_atoms is not None: + targets_out[target_key] = metatensor.torch.slice( + targets_out[target_key], "samples", selected_atoms + ) + + if not target.per_atom: + targets_out[target_key] = metatensor.torch.sum_over_samples( + targets_out[target_key], sample_names="atom" + ) + + return targets_out - return TensorMap(keys=new_keys_labels, blocks=new_blocks) + +def remove_composition( + systems: List[System], + targets: Dict[str, TensorMap], + composition_model: torch.nn.Module, +): + """Remove the composition contribution from the training targets. + + The targets are changed in place. + + :param systems: List of systems. + :param targets: Dictionary containing the targets corresponding to the systems. + :param composition_model: The composition model used to calculate the composition + contribution. + """ + output_options = {} + for target_key in targets: + output_options[target_key] = ModelOutput(per_atom=False) + + composition_targets = composition_model(systems, output_options) + for target_key in targets: + targets[target_key].block().values[:] -= ( + composition_targets[target_key].block().values + ) diff --git a/tests/resources/generate-outputs.sh b/tests/resources/generate-outputs.sh index cd4ea8125..5f41011bb 100755 --- a/tests/resources/generate-outputs.sh +++ b/tests/resources/generate-outputs.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -e +set -eux echo "Generating data for testing..." 
@@ -7,5 +7,5 @@ ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) cd $ROOT_DIR -mtt train options.yaml -o model-32-bit.pt -r base_precision=32 > /dev/null -mtt train options.yaml -o model-64-bit.pt -r base_precision=64 > /dev/null +mtt train options.yaml -o model-32-bit.pt -r base_precision=32 # > /dev/null +mtt train options.yaml -o model-64-bit.pt -r base_precision=64 # > /dev/null diff --git a/tests/utils/test_composition.py b/tests/utils/test_composition.py index 63a75e646..780744664 100644 --- a/tests/utils/test_composition.py +++ b/tests/utils/test_composition.py @@ -1,17 +1,21 @@ from pathlib import Path +import metatensor.torch +import pytest import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import System +from metatensor.torch.atomistic import ModelOutput, System +from omegaconf import OmegaConf -from metatrain.utils.composition import calculate_composition_weights -from metatrain.utils.data import Dataset +from metatrain.utils.composition import CompositionModel, remove_composition +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data.readers import read_systems, read_targets RESOURCES_PATH = Path(__file__).parents[1] / "resources" -def test_calculate_composition_weights(): +def test_composition_model_train(): """Test the calculation of composition weights.""" # Here we use three synthetic structures: @@ -22,14 +26,16 @@ def test_calculate_composition_weights(): systems = [ System( - positions=torch.tensor([[0.0, 0.0, 0.0]]), + positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), types=torch.tensor([8]), - cell=torch.eye(3), + cell=torch.eye(3, dtype=torch.float64), ), System( - positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=torch.float64 + ), types=torch.tensor([1, 1, 8]), - cell=torch.eye(3), + cell=torch.eye(3, dtype=torch.float64), ), System( positions=torch.tensor( @@ -40,10 +46,11 @@ def test_calculate_composition_weights(): [0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0], - ] + ], + dtype=torch.float64, ), types=torch.tensor([1, 1, 8, 1, 1, 8]), - cell=torch.eye(3), + cell=torch.eye(3, dtype=torch.float64), ), ] energies = [1.0, 5.0, 10.0] @@ -52,7 +59,7 @@ def test_calculate_composition_weights(): keys=Labels(names=["_"], values=torch.tensor([[0]])), blocks=[ TensorBlock( - values=torch.tensor([[e]]), + values=torch.tensor([[e]], dtype=torch.float64), samples=Labels(names=["system"], values=torch.tensor([[i]])), components=[], properties=Labels(names=["energy"], values=torch.tensor([[0]])), @@ -63,9 +70,326 @@ def test_calculate_composition_weights(): ] dataset = Dataset.from_dict({"system": systems, "energy": energies}) - weights, atomic_types = calculate_composition_weights(dataset, "energy") + composition_model = CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 8], + targets=TargetInfoDict( + { + "energy": TargetInfo( + quantity="energy", + per_atom=False, + ) + } + ), + ), + ) + + composition_model.train_model(dataset) + assert composition_model.weights.shape[0] == 1 + assert composition_model.weights.shape[1] == 2 + assert composition_model.output_to_output_index == {"energy": 0} + assert composition_model.atomic_types == [1, 8] + torch.testing.assert_close( + composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64) + ) + + 
composition_model.train_model([dataset]) + assert composition_model.weights.shape[0] == 1 + assert composition_model.weights.shape[1] == 2 + assert composition_model.output_to_output_index == {"energy": 0} + assert composition_model.atomic_types == [1, 8] + torch.testing.assert_close( + composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64) + ) - assert len(weights) == len(atomic_types) - assert len(weights) == 2 - assert atomic_types == [1, 8] - torch.testing.assert_close(weights, torch.tensor([2.0, 1.0])) + composition_model.train_model([dataset, dataset, dataset]) + assert composition_model.weights.shape[0] == 1 + assert composition_model.weights.shape[1] == 2 + assert composition_model.output_to_output_index == {"energy": 0} + assert composition_model.atomic_types == [1, 8] + torch.testing.assert_close( + composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64) + ) + + +def test_composition_model_predict(): + """Test the prediction of composition energies.""" + + dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz" + systems = read_systems(dataset_path) + + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": dataset_path, + "file_format": ".xyz", + "reader": "ase", + "key": "U0", + "unit": "eV", + "forces": False, + "stress": False, + "virial": False, + } + } + targets, target_info = read_targets(OmegaConf.create(conf)) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) + + composition_model = CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 6, 7, 8], + targets=target_info, + ), + ) + + composition_model.train_model(dataset) + + # per_atom = False + output = composition_model( + systems[:5], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system"] + assert output["mtt::U0"].block().values.shape == (5, 1) + + # per_atom = True + output = composition_model( + systems[:5], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)}, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system", "atom"] + assert output["mtt::U0"].block().values.shape != (5, 1) + + # with selected_atoms + selected_atoms = metatensor.torch.Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0]]), + ) + + output = composition_model( + systems[:5], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)}, + selected_atoms=selected_atoms, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system", "atom"] + assert output["mtt::U0"].block().values.shape == (1, 1) + + output = composition_model( + systems[:5], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + selected_atoms=selected_atoms, + ) + assert "mtt::U0" in output + assert output["mtt::U0"].block().samples.names == ["system"] + assert output["mtt::U0"].block().values.shape == (1, 1) + + +def test_composition_model_torchscript(tmpdir): + """Test the torchscripting, saving and loading of the composition model.""" + system = System( + positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), + types=torch.tensor([8]), + cell=torch.eye(3, dtype=torch.float64), + ) + + composition_model = CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 8], + targets=TargetInfoDict( + { + "energy": TargetInfo( + quantity="energy", + 
per_atom=False,
+                    )
+                }
+            ),
+        ),
+    )
+    composition_model = torch.jit.script(composition_model)
+    composition_model(
+        [system], {"energy": ModelOutput(quantity="energy", unit="", per_atom=False)}
+    )
+    torch.jit.save(composition_model, tmpdir / "composition_model.pt")
+    composition_model = torch.jit.load(tmpdir / "composition_model.pt")
+    composition_model(
+        [system], {"energy": ModelOutput(quantity="energy", unit="", per_atom=False)}
+    )
+
+
+def test_remove_composition():
+    """Tests the remove_composition function."""
+
+    dataset_path = RESOURCES_PATH / "qm9_reduced_100.xyz"
+    systems = read_systems(dataset_path)
+
+    conf = {
+        "mtt::U0": {
+            "quantity": "energy",
+            "read_from": dataset_path,
+            "file_format": ".xyz",
+            "reader": "ase",
+            "key": "U0",
+            "unit": "eV",
+            "forces": False,
+            "stress": False,
+            "virial": False,
+        }
+    }
+    targets, target_info = read_targets(OmegaConf.create(conf))
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})
+
+    composition_model = CompositionModel(
+        model_hypers={},
+        dataset_info=DatasetInfo(
+            length_unit="angstrom",
+            atomic_types=[1, 6, 7, 8],
+            targets=target_info,
+        ),
+    )
+    composition_model.train_model(dataset)
+
+    # concatenate all targets
+    targets["mtt::U0"] = metatensor.torch.join(targets["mtt::U0"], axis="samples")
+
+    std_before = targets["mtt::U0"].block().values.std().item()
+    remove_composition(systems, targets, composition_model)
+    std_after = targets["mtt::U0"].block().values.std().item()
+
+    # In QM9 the composition contribution is very large: the standard deviation
+    # of the energies is reduced by a factor of over 100 upon removing the composition
+    assert std_after * 100.0 < std_before
+
+
+def test_composition_model_missing_types():
+    """
+    Test the error when there are too many or too few types in the dataset
+    compared to those declared at initialization.
+    """
+
+    # Here we use three synthetic structures:
+    # - O atom, with an energy of 1.0
+    # - H2O molecule, with an energy of 5.0
+    # - H4O2 molecule, with an energy of 10.0
+    # The expected composition weights are 2.0 for H and 1.0 for O.
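+    # (From the first structure, w_O = 1.0; from the second, 2 * w_H + w_O = 5.0,
+    # so w_H = 2.0; the third structure, 4 * w_H + 2 * w_O = 10.0, is consistent.)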
+ + systems = [ + System( + positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), + types=torch.tensor([8]), + cell=torch.eye(3, dtype=torch.float64), + ), + System( + positions=torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=torch.float64 + ), + types=torch.tensor([1, 1, 8]), + cell=torch.eye(3, dtype=torch.float64), + ), + System( + positions=torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + ], + dtype=torch.float64, + ), + types=torch.tensor([1, 1, 8, 1, 1, 8]), + cell=torch.eye(3, dtype=torch.float64), + ), + ] + energies = [1.0, 5.0, 10.0] + energies = [ + TensorMap( + keys=Labels(names=["_"], values=torch.tensor([[0]])), + blocks=[ + TensorBlock( + values=torch.tensor([[e]], dtype=torch.float64), + samples=Labels(names=["system"], values=torch.tensor([[i]])), + components=[], + properties=Labels(names=["energy"], values=torch.tensor([[0]])), + ) + ], + ) + for i, e in enumerate(energies) + ] + dataset = Dataset.from_dict({"system": systems, "energy": energies}) + + composition_model = CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1], + targets=TargetInfoDict( + { + "energy": TargetInfo( + quantity="energy", + per_atom=False, + ) + } + ), + ), + ) + with pytest.raises( + ValueError, + match="unknown atomic types", + ): + composition_model.train_model(dataset) + + composition_model = CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1, 8, 100], + targets=TargetInfoDict( + { + "energy": TargetInfo( + quantity="energy", + per_atom=False, + ) + } + ), + ), + ) + with pytest.warns( + UserWarning, + match="do not contain atomic types", + ): + composition_model.train_model(dataset) + + +def test_composition_model_wrong_target(): + """ + Test the error when a non-energy is fed to the composition model. + """ + + with pytest.raises( + ValueError, + match="only supports energy-like outputs", + ): + CompositionModel( + model_hypers={}, + dataset_info=DatasetInfo( + length_unit="angstrom", + atomic_types=[1], + targets=TargetInfoDict( + { + "energy": TargetInfo( + quantity="FOO", + per_atom=False, + ) + } + ), + ), + )
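
Note: below is a minimal usage sketch of the new CompositionModel, distilled
from the tests above. The two synthetic structures (an isolated O atom and an
H2O molecule) are illustrative assumptions chosen so that the fitted weights
come out exactly as w_H = 2.0 and w_O = 1.0; all imports and calls follow
tests/utils/test_composition.py.

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap
    from metatensor.torch.atomistic import ModelOutput, System

    from metatrain.utils.composition import CompositionModel, remove_composition
    from metatrain.utils.data import (
        Dataset,
        DatasetInfo,
        TargetInfo,
        TargetInfoDict,
    )

    def make_energy(value: float, system_index: int) -> TensorMap:
        # Wrap a scalar energy in the TensorMap layout used by the tests above
        return TensorMap(
            keys=Labels(names=["_"], values=torch.tensor([[0]])),
            blocks=[
                TensorBlock(
                    values=torch.tensor([[value]], dtype=torch.float64),
                    samples=Labels(
                        names=["system"], values=torch.tensor([[system_index]])
                    ),
                    components=[],
                    properties=Labels(names=["energy"], values=torch.tensor([[0]])),
                )
            ],
        )

    # float64 positions and cells are required for training the composition model
    systems = [
        System(  # isolated O atom, energy 1.0
            positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64),
            types=torch.tensor([8]),
            cell=torch.eye(3, dtype=torch.float64),
        ),
        System(  # H2O molecule, energy 5.0
            positions=torch.tensor(
                [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                dtype=torch.float64,
            ),
            types=torch.tensor([1, 1, 8]),
            cell=torch.eye(3, dtype=torch.float64),
        ),
    ]
    energies = [make_energy(e, i) for i, e in enumerate([1.0, 5.0])]
    dataset = Dataset.from_dict({"system": systems, "energy": energies})

    model = CompositionModel(
        model_hypers={},
        dataset_info=DatasetInfo(
            length_unit="angstrom",
            atomic_types=[1, 8],
            targets=TargetInfoDict(
                {"energy": TargetInfo(quantity="energy", per_atom=False)}
            ),
        ),
    )
    model.train_model(dataset)  # solves the regularized system: w_H = 2.0, w_O = 1.0

    # per-structure composition energies: [[1.0], [5.0]]
    predictions = model(
        systems,
        {"energy": ModelOutput(quantity="energy", unit="", per_atom=False)},
    )

    # subtract the composition baseline from a target in place, as the
    # SOAP-BPNN and GAP trainers now do before fitting the residual energies
    remove_composition([systems[1]], {"energy": energies[1]}, model)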