diff --git a/src/metatensor/models/cli/eval_model.py b/src/metatensor/models/cli/eval_model.py index 1b8839b30..e9c9ab1f6 100644 --- a/src/metatensor/models/cli/eval_model.py +++ b/src/metatensor/models/cli/eval_model.py @@ -56,5 +56,9 @@ def eval_model(model: str, structures: str, output: str = "output.xyz") -> None: loaded_model = load_model(model) structure_list = read_structures(structures) + + # since the second argument is missing, + # this calculates all the available properties: predictions = loaded_model(structure_list) + write_predictions(output, predictions, structure_list) diff --git a/src/metatensor/models/soap_bpnn/model.py b/src/metatensor/models/soap_bpnn/model.py index 1e05ed8e2..aa44116c1 100644 --- a/src/metatensor/models/soap_bpnn/model.py +++ b/src/metatensor/models/soap_bpnn/model.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional import metatensor.torch import rascaline.torch @@ -35,29 +35,23 @@ def __init__(self, all_species: List[int], hypers: dict) -> None: # Build a neural network for each species nns_per_species = [] for _ in all_species: - module_list = [ - torch.nn.Linear(hypers["input_size"], hypers["num_neurons_per_layer"]), - torch.nn.SiLU(), - ] + module_list: List[torch.nn.Module] = [] for _ in range(hypers["num_hidden_layers"]): - module_list.append( - torch.nn.Linear( - hypers["num_neurons_per_layer"], hypers["num_neurons_per_layer"] + if len(module_list) == 0: + module_list.append( + torch.nn.Linear( + hypers["input_size"], hypers["num_neurons_per_layer"] + ) ) - ) - module_list.append(torch.nn.SiLU()) - - # If there are no hidden layers, the number of inputs - # for the last layer is the input size - n_inputs_last_layer = ( - hypers["num_neurons_per_layer"] - if hypers["num_hidden_layers"] > 0 - else hypers["input_size"] - ) + else: + module_list.append( + torch.nn.Linear( + hypers["num_neurons_per_layer"], + hypers["num_neurons_per_layer"], + ) + ) + module_list.append(self.activation_function) - module_list.append( - torch.nn.Linear(n_inputs_last_layer, hypers["output_size"]) - ) nns_per_species.append(torch.nn.Sequential(*module_list)) # Create a module dict to store the neural networks @@ -75,9 +69,7 @@ def forward(self, features: TensorMap) -> TensorMap: new_blocks: List[TensorBlock] = [] for species_str, network in self.layers.items(): species = int(species_str) - if species not in present_blocks: - pass # continue is not accepted by torchscript here - else: + if species in present_blocks: block = features.block({"species_center": species}) output_values = network(block.values) new_blocks.append( @@ -92,6 +84,48 @@ def forward(self, features: TensorMap) -> TensorMap: return TensorMap(keys=features.keys, blocks=new_blocks) +class LinearMap(torch.nn.Module): + def __init__(self, all_species: List[int], n_inputs: int) -> None: + super().__init__() + + # Build a neural network for each species + layer_per_species = [] + for _ in all_species: + layer_per_species.append(torch.nn.Linear(n_inputs, 1)) + + # Create a module dict to store the neural networks + self.layers = torch.nn.ModuleDict( + { + str(species): layer + for species, layer in zip(all_species, layer_per_species) + } + ) + + def forward(self, features: TensorMap) -> TensorMap: + # Create a list of the blocks that are present in the features: + present_blocks = [ + int(features.keys.entry(i).values.item()) + for i in range(features.keys.values.shape[0]) + ] + + new_blocks: List[TensorBlock] = [] + for species_str, layer in self.layers.items(): + 
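+            # apply this species' linear layer only if the corresponding
+            # block is present in the input features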
+            species = int(species_str)
+            if species in present_blocks:
+                block = features.block({"species_center": species})
+                output_values = layer(block.values)
+                new_blocks.append(
+                    TensorBlock(
+                        values=output_values,
+                        samples=block.samples,
+                        components=block.components,
+                        properties=Labels.single(),
+                    )
+                )
+
+        return TensorMap(keys=features.keys, blocks=new_blocks)
+
+
 class Model(torch.nn.Module):
     def __init__(
         self, capabilities: ModelCapabilities, hypers: Dict = DEFAULT_MODEL_HYPERS
@@ -100,16 +134,18 @@ def __init__(
         self.name = ARCHITECTURE_NAME

         # Check capabilities
-        if len(capabilities.outputs) > 1:
-            raise ValueError(
-                "SOAP-BPNN only supports a single output, "
-                "but multiple outputs were provided"
-            )
-        if next(iter(capabilities.outputs.values())).quantity != "energy":
-            raise ValueError(
-                "SOAP-BPNN only supports energy-like outputs, "
-                f"but {next(iter(capabilities.outputs.values())).quantity} was provided"
-            )
+        for output in capabilities.outputs.values():
+            if output.quantity != "energy":
+                raise ValueError(
+                    "SOAP-BPNN only supports energy-like outputs, "
+                    f"but a {output.quantity} was provided"
+                )
+            if output.per_atom:
+                raise ValueError(
+                    "SOAP-BPNN only supports per-structure outputs, "
+                    "but a per-atom output was provided"
+                )

         self.capabilities = capabilities
         self.all_species = capabilities.species
@@ -118,9 +154,16 @@ def __init__(
         # creates a composition weight tensor that can be directly indexed by species,
         # this can be left as a tensor of zero or set from the outside using
         # set_composition_weights (recommended for better accuracy)
+        n_outputs = len(capabilities.outputs)
         self.register_buffer(
-            "composition_weights", torch.zeros(max(self.all_species) + 1)
+            "composition_weights", torch.zeros((n_outputs, max(self.all_species) + 1))
         )
+        # buffers cannot be indexed by strings (torchscript), so we create a single
+        # tensor for all outputs. Because of this, we need to slice the tensor when
+        # we use it, selecting the correct slice via a dictionary of output names
+        self.output_to_index = {
+            output_name: i for i, output_name in enumerate(capabilities.outputs.keys())
+        }

         self.soap_calculator = rascaline.torch.SoapPowerSpectrum(**hypers["soap"])

         hypers_bpnn = hypers["bpnn"]
@@ -129,7 +172,7 @@ def __init__(
             * hypers["soap"]["max_radial"] ** 2
             * (hypers["soap"]["max_angular"] + 1)
         )
-        hypers_bpnn["output_size"] = 1
+
         self.bpnn = MLPMap(self.all_species, hypers_bpnn)

         self.neighbor_species_1_labels = Labels(
             names=["species_neighbor_1"],
             values=torch.tensor(self.all_species).reshape(-1, 1),
         )
@@ -140,7 +183,31 @@ def __init__(
             values=torch.tensor(self.all_species).reshape(-1, 1),
         )

-    def forward(self, systems: List[System]) -> Dict[str, TensorMap]:
+        if hypers_bpnn["num_hidden_layers"] == 0:
+            n_inputs_last_layer = hypers_bpnn["input_size"]
+        else:
+            n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"]
+
+        self.last_layers = torch.nn.ModuleDict(
+            {
+                output_name: LinearMap(self.all_species, n_inputs_last_layer)
+                for output_name in capabilities.outputs.keys()
+            }
+        )
+
+    def forward(
+        self, systems: List[System], requested_outputs: Optional[List[str]] = None
+    ) -> Dict[str, TensorMap]:
+        if requested_outputs is None:  # default to all outputs
+            requested_outputs = list(self.capabilities.outputs.keys())
+
+        for requested_output in requested_outputs:
+            if requested_output not in self.capabilities.outputs.keys():
+                raise ValueError(
+                    f"Requested output {requested_output} is not within "
+                    "the model's capabilities."
+ ) + soap_features = self.soap_calculator(systems) device = soap_features.block(0).values.device @@ -151,19 +218,39 @@ def forward(self, systems: List[System]) -> Dict[str, TensorMap]: self.neighbor_species_2_labels.to(device) ) - atomic_energies = self.bpnn(soap_features) - atomic_energies = apply_composition_contribution( - atomic_energies, self.composition_weights - ) - atomic_energies = atomic_energies.keys_to_samples("species_center") + hidden_features = self.bpnn(soap_features) + + atomic_energies: Dict[str, metatensor.torch.TensorMap] = {} + for output_name, output_layer in self.last_layers.items(): + if output_name in requested_outputs: + atomic_energies[output_name] = apply_composition_contribution( + output_layer(hidden_features), + self.composition_weights[self.output_to_index[output_name]], + ) # Sum the atomic energies coming from the BPNN to get the total energy - total_energies = metatensor.torch.sum_over_samples( - atomic_energies, ["center", "species_center"] - ) + total_energies: Dict[str, metatensor.torch.TensorMap] = {} + for output_name, atomic_energy in atomic_energies.items(): + atomic_energy = atomic_energy.keys_to_samples("species_center") + total_energies[output_name] = metatensor.torch.sum_over_samples( + atomic_energy, ["center", "species_center"] + ) + # Change the energy label from _ to (0, 1): + total_energies[output_name] = metatensor.torch.TensorMap( + keys=Labels( + names=["lambda", "sigma"], + values=torch.tensor([[0, 1]]), + ), + blocks=[total_energies[output_name].block()], + ) - return {"energy": total_energies} + return total_energies - def set_composition_weights(self, input_composition_weights: torch.Tensor) -> None: + def set_composition_weights( + self, output_name: str, input_composition_weights: torch.Tensor + ) -> None: + """Set the composition weights for a given output.""" # all species that are not present retain their weight of zero - self.composition_weights[self.all_species] = input_composition_weights + self.composition_weights[self.output_to_index[output_name]][ + self.all_species + ] = input_composition_weights diff --git a/src/metatensor/models/soap_bpnn/tests/test_regression.py b/src/metatensor/models/soap_bpnn/tests/test_regression.py index e78f8b697..5e3ca1493 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_regression.py +++ b/src/metatensor/models/soap_bpnn/tests/test_regression.py @@ -35,12 +35,12 @@ def test_regression_init(): [rascaline.torch.systems_to_torch(structure) for structure in structures] ) expected_output = torch.tensor( - [ - [[0.5021], [0.3809], [0.1849], [0.2126], [0.0920]], - ], + [[-0.4615], [-0.4367], [-0.3004], [-0.2606], [-0.2380]], dtype=torch.float64, ) + print(output["energy"].block().values) + assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3) @@ -61,8 +61,10 @@ def test_regression_train(): output = soap_bpnn(structures[:5]) expected_output = torch.tensor( - [[-40.5923], [-56.5135], [-76.4457], [-77.2500], [-93.1583]], + [[-39.9658], [-56.0888], [-76.1100], [-76.9461], [-93.0914]], dtype=torch.float64, ) - assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3) + print(output["U0"].block().values) + + assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3) diff --git a/src/metatensor/models/soap_bpnn/train.py b/src/metatensor/models/soap_bpnn/train.py index 4720a4186..6d1b7c0cf 100644 --- a/src/metatensor/models/soap_bpnn/train.py +++ b/src/metatensor/models/soap_bpnn/train.py @@ -5,7 +5,9 @@ from 
metatensor.torch.atomistic import ModelCapabilities, ModelOutput from ..utils.composition import calculate_composition_weights +from ..utils.compute_loss import compute_model_loss from ..utils.data import collate_fn +from ..utils.loss import TensorMapDictLoss from ..utils.model_io import save_model from .model import DEFAULT_HYPERS, Model @@ -13,17 +15,21 @@ logger = logging.getLogger(__name__) -def loss_function(predicted, target): - return torch.sum((predicted.block().values - target.block().values) ** 2) - - def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."): + if len(train_dataset.targets) > 1: + raise ValueError( + f"`train_dataset` contains {len(train_dataset.targets)} targets but we " + "currently only support a single target value!" + ) + else: + target_name = list(train_dataset.targets.keys())[0] + # Set the model's capabilities: model_capabilities = ModelCapabilities( length_unit="Angstrom", species=train_dataset.all_species, outputs={ - "U0": ModelOutput( + target_name: ModelOutput( quantity="energy", unit="eV", ) @@ -36,17 +42,9 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."): hypers=hypers["model"], ) - if len(train_dataset.targets) > 1: - raise ValueError( - f"`train_dataset` contains {len(train_dataset.targets)} targets but we " - "currently only support a single target value!" - ) - else: - target = list(train_dataset.targets.keys())[0] - # Calculate and set the composition weights: - composition_weights = calculate_composition_weights(train_dataset, target) - model.set_composition_weights(composition_weights) + composition_weights = calculate_composition_weights(train_dataset, target_name) + model.set_composition_weights(target_name, composition_weights) hypers_training = hypers["training"] @@ -58,6 +56,11 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."): collate_fn=collate_fn, ) + # Create a loss function: + loss_fn = TensorMapDictLoss( + {target_name: {"values": 1.0}}, + ) + # Create an optimizer: optimizer = torch.optim.Adam( model.parameters(), lr=hypers_training["learning_rate"] @@ -75,8 +78,7 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."): for batch in train_dataloader: optimizer.zero_grad() structures, targets = batch - predicted = model(structures) - loss = loss_function(predicted["energy"], targets["U0"]) + loss = compute_model_loss(loss_fn, model, structures, targets) loss.backward() optimizer.step() diff --git a/src/metatensor/models/utils/compute_loss.py b/src/metatensor/models/utils/compute_loss.py new file mode 100644 index 000000000..2e64ff37a --- /dev/null +++ b/src/metatensor/models/utils/compute_loss.py @@ -0,0 +1,210 @@ +from typing import Dict, List + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import System + +from .loss import TensorMapDictLoss +from .output_gradient import compute_gradient + + +def compute_model_loss( + loss: TensorMapDictLoss, + model: torch.nn.Module, + systems: List[System], + targets: Dict[str, TensorMap], +): + """ + Compute the loss of a model on a set of targets. + + :param loss: The loss function to use. + :param model: The model to use. + :param systems: The systems to use. + :param targets: The targets to use. + + :returns: The loss as a scalar `torch.Tensor`. 
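+
+    A minimal usage sketch, mirroring the training loop in
+    ``soap_bpnn/train.py`` (``model``, ``systems`` and ``targets`` are
+    assumed to come from that loop; add e.g. ``"positions": 10.0`` to the
+    inner dictionary to also weight position gradients):
+
+    >>> loss_fn = TensorMapDictLoss({"U0": {"values": 1.0}})
+    >>> loss = compute_model_loss(loss_fn, model, systems, targets)
+    >>> loss.backward()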
+ """ + # Assert that all targets are within the model's capabilities: + if not set(targets.keys()).issubset(model.capabilities.outputs.keys()): + raise ValueError("Not all targets are within the model's capabilities.") + + # Find if there are any energy targets that require gradients: + energy_targets = [] + energy_targets_that_require_position_gradients = [] + energy_targets_that_require_displacement_gradients = [] + for target_name in targets.keys(): + # Check if the target is an energy: + if model.capabilities.outputs[target_name].quantity == "energy": + energy_targets.append(target_name) + # Check if the energy requires gradients: + if targets[target_name].block().has_gradient("positions"): + energy_targets_that_require_position_gradients.append(target_name) + if targets[target_name].block().has_gradient("displacements"): + energy_targets_that_require_displacement_gradients.append(target_name) + + if len(energy_targets_that_require_displacement_gradients) > 0: + # TODO: raise an error if the systems do not have a cell + # if not all([system.has_cell for system in systems]): + # raise ValueError("One or more systems does not have a cell.") + displacements = [ + torch.eye( + 3, + requires_grad=True, + dtype=system.cell.dtype, + device=system.cell.device, + ) + for system in systems + ] + # Create new "displaced" systems: + systems = [ + System( + positions=system.positions @ displacement, + cell=system.cell @ displacement, + species=system.species, + ) + for system, displacement in zip(systems, displacements) + ] + else: + if len(energy_targets_that_require_position_gradients) > 0: + # Set positions to require gradients: + for system in systems: + system.positions.requires_grad_(True) + + # Based on the keys of the targets, get the outputs of the model: + model_outputs = model(systems, targets.keys()) + + for energy_target in energy_targets: + # If the energy target requires gradients, compute them: + target_requires_pos_gradients = ( + energy_target in energy_targets_that_require_position_gradients + ) + target_requires_disp_gradients = ( + energy_target in energy_targets_that_require_displacement_gradients + ) + if target_requires_pos_gradients and target_requires_disp_gradients: + gradients = compute_gradient( + model_outputs[energy_target].block().values, + [system.positions for system in systems] + displacements, + is_training=True, + ) + old_energy_tensor_map = model_outputs[energy_target] + new_block = old_energy_tensor_map.block().copy() + new_block.add_gradient( + "positions", _position_gradients_to_block(gradients[: len(systems)]) + ) + new_block.add_gradient( + "displacements", + _displacement_gradients_to_block(gradients[len(systems) :]), + ) + new_energy_tensor_map = TensorMap( + keys=old_energy_tensor_map.keys, + blocks=[new_block], + ) + model_outputs[energy_target] = new_energy_tensor_map + elif target_requires_pos_gradients: + gradients = compute_gradient( + model_outputs[energy_target].block().values, + [system.positions for system in systems], + is_training=True, + ) + old_energy_tensor_map = model_outputs[energy_target] + new_block = old_energy_tensor_map.block().copy() + new_block.add_gradient("positions", _position_gradients_to_block(gradients)) + new_energy_tensor_map = TensorMap( + keys=old_energy_tensor_map.keys, + blocks=[new_block], + ) + model_outputs[energy_target] = new_energy_tensor_map + elif target_requires_disp_gradients: + gradients = compute_gradient( + model_outputs[energy_target].block().values, + displacements, + is_training=True, + ) + 
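+            # only displacement gradients were requested for this target,
+            # so attach them to a copy of the energy block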
old_energy_tensor_map = model_outputs[energy_target] + new_block = old_energy_tensor_map.block().copy() + new_block.add_gradient( + "displacements", _displacement_gradients_to_block(gradients) + ) + new_energy_tensor_map = TensorMap( + keys=old_energy_tensor_map.keys, + blocks=[new_block], + ) + model_outputs[energy_target] = new_energy_tensor_map + else: + pass + + # Compute the loss: + return loss(model_outputs, targets) + + +def _position_gradients_to_block(gradients_list): + """Convert a list of position gradients to a `TensorBlock` + which can act as a gradient block to an energy block.""" + + # `gradients` consists of a list of tensors where the second dimension is 3 + gradients = torch.concatenate(gradients_list, dim=0).unsqueeze(-1) + # unsqueeze for the property dimension + + samples = Labels( + names=["sample", "atom"], + values=torch.stack( + [ + torch.concatenate( + [ + torch.tensor([i] * len(structure)) + for i, structure in enumerate(gradients_list) + ] + ), + torch.concatenate( + [torch.arange(len(structure)) for structure in gradients_list] + ), + ], + dim=1, + ), + ) + + components = [ + Labels( + names=["coordinate"], + values=torch.tensor([[0], [1], [2]]), + ) + ] + + return TensorBlock( + values=gradients, + samples=samples, + components=components, + properties=Labels.single(), + ) + + +def _displacement_gradients_to_block(gradients_list): + """Convert a list of displacement gradients to a `TensorBlock` + which can act as a gradient block to an energy block.""" + + # `gradients` consists of a list of tensors where the second dimension is 3 + gradients = torch.concatenate(gradients_list, dim=0).unsqueeze(-1) + # unsqueeze for the property dimension + + samples = Labels( + names=["sample"], values=torch.arange(len(gradients_list)).unsqueeze(-1) + ) + + components = [ + Labels( + names=["cell vector"], + values=torch.tensor([[0], [1], [2]]), + ), + Labels( + names=["coordinate"], + values=torch.tensor([[0], [1], [2]]), + ), + ] + + return TensorBlock( + values=gradients, + samples=samples, + components=components, + properties=Labels.single(), + ) diff --git a/src/metatensor/models/utils/data/readers/targets/ase.py b/src/metatensor/models/utils/data/readers/targets/ase.py index 0f3c8a5d2..88e593cb5 100644 --- a/src/metatensor/models/utils/data/readers/targets/ase.py +++ b/src/metatensor/models/utils/data/readers/targets/ase.py @@ -33,9 +33,11 @@ def read_ase( values=torch.tensor(values).reshape(-1, 1), samples=Labels(["structure"], torch.arange(n_structures).reshape(-1, 1)), components=[], - properties=Labels(["energy"], torch.tensor([(0,)])), + properties=Labels.single(), ) - target_dictionary[target_value] = TensorMap(Labels.single(), [block]) + target_dictionary[target_value] = TensorMap( + keys=Labels(["lambda", "sigma"], torch.tensor([(0, 1)])), blocks=[block] + ) return target_dictionary diff --git a/src/metatensor/models/utils/data/writers/xyz.py b/src/metatensor/models/utils/data/writers/xyz.py index 7f5d11819..0c65e7b06 100644 --- a/src/metatensor/models/utils/data/writers/xyz.py +++ b/src/metatensor/models/utils/data/writers/xyz.py @@ -14,9 +14,14 @@ def write_xyz(filename: str, predictions: TensorMap, structures: List[System]) - :param predictions: prediction values written to the file. :param structures: strcutures additional written to the file. 
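+
+    Note that only the first target in ``predictions`` is written; its value
+    is stored under its own name in each frame's ``info`` dictionary.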
""" + # Get the target property name: + target_name = next(iter(predictions.keys())) + frames = [] for i_system, system in enumerate(structures): - info = {"energy": float(predictions["energy"].block().values[i_system, 0])} + info = { + target_name: float(predictions[target_name].block().values[i_system, 0]) + } atoms = ase.Atoms(symbols=system.species, positions=system.positions, info=info) if torch.any(system.cell != 0): diff --git a/src/metatensor/models/utils/loss.py b/src/metatensor/models/utils/loss.py new file mode 100644 index 000000000..3809dbf41 --- /dev/null +++ b/src/metatensor/models/utils/loss.py @@ -0,0 +1,157 @@ +from typing import Dict, Optional + +import torch +from metatensor.torch import TensorMap + + +# This file defines losses for metatensor models. + + +class TensorMapLoss: + """ + A loss function that operates on two `metatensor.torch.TensorMap`s. + + The loss is computed as the sum of the loss on the block values and + the loss on the gradients, with weights specified at initialization. + + At the moment, this loss function assumes that all the gradients + declared at initialization are present in both TensorMaps. + + :param reduction: The reduction to apply to the loss. See `torch.nn.MSELoss`. + :param weight: The weight to apply to the loss on the block values. + :param gradient_weights: The weights to apply to the loss on the gradients. + + :returns: The loss as a scalar `torch.Tensor`. + """ + + def __init__( + self, + reduction: str = "mean", + weight: float = 1.0, + gradient_weights: Optional[Dict[str, float]] = None, + ): + self.loss = torch.nn.MSELoss(reduction=reduction) + self.weight = weight + self.gradient_weights = {} if gradient_weights is None else gradient_weights + + def __call__( + self, tensor_map_1: TensorMap, tensor_map_2: TensorMap + ) -> torch.Tensor: + # Check that the two have the same metadata, except for the samples, + # which can be different due to batching, but must have the same size: + if tensor_map_1.keys != tensor_map_2.keys: + raise ValueError( + "TensorMapLoss requires the two TensorMaps to have the same keys." + ) + if tensor_map_1.block().properties != tensor_map_2.block().properties: + raise ValueError( + "TensorMapLoss requires the two TensorMaps to have the same properties." + ) + if tensor_map_1.block().components != tensor_map_2.block().components: + raise ValueError( + "TensorMapLoss requires the two TensorMaps to have the same components." + ) + if len(tensor_map_1.block().samples) != len(tensor_map_2.block().samples): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same number of samples." + ) + for gradient_name in self.gradient_weights.keys(): + if len(tensor_map_1.block().gradient(gradient_name).samples) != len( + tensor_map_2.block().gradient(gradient_name).samples + ): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same number of gradient samples." + ) + if ( + tensor_map_1.block().gradient(gradient_name).properties + != tensor_map_2.block().gradient(gradient_name).properties + ): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same gradient properties." + ) + if ( + tensor_map_1.block().gradient(gradient_name).components + != tensor_map_2.block().gradient(gradient_name).components + ): + raise ValueError( + "TensorMapLoss requires the two TensorMaps " + "to have the same gradient components." 
+ ) + + # If the two TensorMaps have different symmetry keys: + if len(tensor_map_1) != 1: + raise NotImplementedError( + "TensorMapLoss does not yet support multiple symmetry keys." + ) + + # Compute the loss: + loss = torch.zeros( + (), + dtype=tensor_map_1.block().values.dtype, + device=tensor_map_1.block().values.device, + ) + loss += self.weight * self.loss( + tensor_map_1.block().values, tensor_map_2.block().values + ) + for gradient_name, gradient_weight in self.gradient_weights.items(): + loss += gradient_weight * self.loss( + tensor_map_1.block().gradient(gradient_name).values, + tensor_map_2.block().gradient(gradient_name).values, + ) + + return loss + + +class TensorMapDictLoss: + """ + A loss function that operates on two `Dict[str, metatensor.torch.TensorMap]`. + + At initialization, the user specifies a list of keys to use for the loss, + along with a weight for each key (as well as gradient weights). + + The loss is then computed as a weighted sum. Any keys that are not present + in the dictionaries are ignored. + + :param weights: A dictionary mapping keys to weights. Each weight is itself + a dictionary mapping "values" to the weight to apply to the loss on the + block values, and gradient names to the weights to apply to the loss on + the gradients. + :param reduction: The reduction to apply to the loss. See `torch.nn.MSELoss`. + + :returns: The loss as a scalar `torch.Tensor`. + """ + + def __init__( + self, + weights: Dict[str, Dict[str, float]], + reduction: str = "mean", + ): + self.losses = {} + for key, weight in weights.items(): + # Remove the value weight from the gradient weights and store it separately: + value_weight = weight.pop("values") + # Define the loss relative to this key: + self.losses[key] = TensorMapLoss( + reduction=reduction, weight=value_weight, gradient_weights=weight + ) + + def __call__( + self, + tensor_map_dict_1: Dict[str, TensorMap], + tensor_map_dict_2: Dict[str, TensorMap], + ) -> torch.Tensor: + # Assert that the two have the keys: + assert set(tensor_map_dict_1.keys()) == set(tensor_map_dict_2.keys()) + + # Initialize the loss: + first_values = next(iter(tensor_map_dict_1.values())).block(0).values + loss = torch.zeros((), dtype=first_values.dtype, device=first_values.device) + + # Compute the loss: + for key in tensor_map_dict_1.keys(): + loss += self.losses[key](tensor_map_dict_1[key], tensor_map_dict_2[key]) + + return loss diff --git a/src/metatensor/models/utils/output_gradient.py b/src/metatensor/models/utils/output_gradient.py new file mode 100644 index 000000000..dda6888d2 --- /dev/null +++ b/src/metatensor/models/utils/output_gradient.py @@ -0,0 +1,31 @@ +from typing import List, Optional + +import torch + + +def compute_gradient( + target: torch.Tensor, inputs: List[torch.Tensor], is_training: bool +) -> List[torch.Tensor]: + """ + Calculates the gradient of a target tensor with respect to a list of input tensors. + + ``target`` must be a single torch.Tensor object. If target contains multiple values, + the gradient will be calculated with respect to the sum of all values. + """ + + grad_outputs: Optional[List[Optional[torch.Tensor]]] = [torch.ones_like(target)] + gradient = torch.autograd.grad( + outputs=[target], + inputs=inputs, + grad_outputs=grad_outputs, + retain_graph=is_training, + create_graph=is_training, + ) + if gradient is None: + raise ValueError( + "Unexpected None value for computed gradient. " + "One or more operations inside the model might " + "not have a gradient implementation." 
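+            # (defensive check; in practice ``torch.autograd.grad`` raises
+            # rather than returning None)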
+ ) + else: + return gradient diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py index 183ab24f3..004b22292 100644 --- a/tests/cli/test_eval_model.py +++ b/tests/cli/test_eval_model.py @@ -26,4 +26,4 @@ def test_eval(output, monkeypatch, tmp_path): subprocess.check_call(command) frames = ase.io.read(output, ":") - frames[0].info["energy"] + frames[0].info["U0"] diff --git a/tests/resources/alchemical_reduced_10.xyz b/tests/resources/alchemical_reduced_10.xyz new file mode 100644 index 000000000..b7820449e --- /dev/null +++ b/tests/resources/alchemical_reduced_10.xyz @@ -0,0 +1,416 @@ +36 +Lattice="10.684835525 0.0 0.0 0.0 10.684835525 0.0 0.0 0.0 7.123223683" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=1.0625018098857693 fps_idx=-1 name=3709 energy=-274.79272153 stress="0.08617869478775508 -0.004215009701382356 0.001135424132635053 -0.004215009701382356 0.08703453676060532 -5.7091083974454166e-05 0.001135424132635053 -5.7091083974454166e-05 0.08197792821019799" free_energy=-274.80386702 pbc="T T T" +Sc 8.90403000 5.34242000 5.34242000 1.84954700 -0.60305200 0.01523300 +Sc 1.78081000 5.34242000 1.78081000 -0.49242500 -0.74499900 -0.27740800 +Sc 3.56161000 7.12322000 3.56161000 1.15562000 -0.00043800 -0.06187800 +Sc 0.00000000 0.00000000 3.56161000 -0.85511700 0.01901000 -0.67265000 +Sc 8.90403000 1.78081000 5.34242000 -0.27434500 0.20867500 -0.03951900 +Sc 0.00000000 0.00000000 0.00000000 -1.18264700 -0.54330300 0.62220800 +Sc 3.56161000 0.00000000 3.56161000 1.25239000 -1.18703600 -0.40478100 +V 3.56161000 7.12322000 0.00000000 1.31412200 -0.52462800 -0.15144200 +V 8.90403000 5.34242000 1.78081000 1.06924400 -0.60758500 0.26851100 +V 5.34242000 8.90403000 5.34242000 0.22253700 0.72659900 0.34632000 +V 0.00000000 7.12322000 0.00000000 -0.55686700 -0.08970400 0.17212400 +V 7.12322000 0.00000000 0.00000000 0.48191100 -1.30941400 0.48553700 +V 5.34242000 5.34242000 1.78081000 -0.39142100 -0.80700800 -0.70213500 +Cr 7.12322000 0.00000000 3.56161000 0.01381200 -1.05781000 -0.46442800 +Cr 0.00000000 7.12322000 3.56161000 -0.45647900 -0.34033600 -0.09011900 +Cr 5.34242000 8.90403000 1.78081000 0.10052000 0.58999000 -0.24462500 +Cr 0.00000000 3.56161000 3.56161000 -0.23056400 0.13935000 -0.52112400 +Cu 3.56161000 3.56161000 3.56161000 -0.00713800 0.27627400 -0.08896200 +Cu 3.56161000 0.00000000 0.00000000 0.01210400 -0.33924500 0.13192400 +Cu 0.00000000 3.56161000 0.00000000 -0.18990800 0.01517500 0.19734500 +Y 7.12322000 7.12322000 0.00000000 -0.72846500 0.33287000 1.25593600 +Y 1.78081000 8.90403000 5.34242000 0.03014400 -1.22317300 0.72989600 +Y 1.78081000 8.90403000 1.78081000 -0.05851900 -0.36160900 -0.71290700 +Y 1.78081000 1.78081000 5.34242000 0.32779200 1.96897500 0.33679600 +Y 5.34242000 1.78081000 5.34242000 -0.21380300 -0.15575200 0.75537500 +Y 7.12322000 7.12322000 3.56161000 -0.58523000 0.43246900 -1.23375900 +Y 5.34242000 1.78081000 1.78081000 -0.31247600 -0.10155400 -0.65566900 +Nb 7.12322000 3.56161000 0.00000000 0.02211900 1.29882200 0.92695800 +Nb 1.78081000 5.34242000 5.34242000 -0.77269800 0.31461400 0.20306300 +Nb 8.90403000 8.90403000 1.78081000 0.33117200 0.32551100 0.14628400 +Hf 5.34242000 5.34242000 5.34242000 -1.57041900 -1.08461800 0.84930100 +Hf 7.12322000 3.56161000 3.56161000 0.75825900 1.23749500 -0.92562000 +Hf 8.90403000 1.78081000 1.78081000 -0.17586500 0.38581400 0.05539300 +Hf 1.78081000 1.78081000 1.78081000 -0.00032300 1.64333700 -0.13970500 +Hf 8.90403000 8.90403000 5.34242000 0.26984900 0.53720200 -0.20491000 +Pt 3.56161000 
3.56161000 0.00000000 -0.15643700 0.62908400 0.09343600 +36 +Lattice="9.635884875 0.0 0.0 0.0 9.635884875 0.0 0.0 0.0 6.42392325" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=0.9434298026410756 fps_idx=-1 name=2996 energy=-286.87287766 stress="-0.09254085228504165 0.0024403551701108433 -0.006314228948708925 0.0024403551701108433 -0.09549506465997734 0.006802726901955303 -0.006314228948708925 0.006802726901955303 -0.09694182775082971" free_energy=-286.87390508 pbc="T T T" +Sc 0.02921000 3.07534000 3.30779000 0.35901200 2.64848100 -1.17640100 +Sc 4.85256000 1.64607000 1.60982000 -2.32323000 -0.19597400 -1.69855600 +Sc 0.08519000 6.27742000 6.41241000 0.13125800 1.30246100 -0.06136400 +Sc 8.04074000 8.08267000 1.57201000 -0.17883000 -0.39156300 -0.89605200 +Sc 4.89356000 4.93023000 4.84477000 -2.50677700 -3.40979100 -0.03795700 +Sc 6.49174000 0.08619000 3.26806000 -0.68763900 -1.05644400 0.01069600 +Sc 1.69208000 8.06353000 4.75564000 1.52226400 0.64117800 -0.79877100 +Sc 7.99604000 1.57451000 1.65486000 -1.07168300 0.59373800 -2.22462900 +Sc 4.83462000 1.44380000 4.84776000 -2.73325400 1.13088300 1.00102800 +Sc 9.53357000 0.07041000 3.12542000 1.38976400 -1.61965900 -0.06207900 +Sc 3.12921000 6.48182000 0.02447000 -1.36169400 -0.53665100 -0.03391100 +Sc 1.57723000 4.85359000 1.60382000 2.08439900 -0.59142800 0.60903000 +Co 3.17493000 6.45886000 3.26211000 0.18155600 -0.32115400 -0.21299300 +Co 1.51310000 5.00263000 4.75398000 0.82406900 -0.77435300 -0.27676800 +Co 6.47202000 3.22036000 0.02547000 0.03286900 0.27779900 0.14938100 +Co 3.21609000 3.38336000 6.36089000 -0.19411100 -0.05843600 0.37146400 +Co 3.25525000 0.11780000 3.09212000 -0.47316600 -0.22661100 -0.18194700 +Co 3.23903000 9.50952000 0.20074000 -0.09490600 -0.11672600 -0.35531800 +Co 1.68911000 7.99308000 1.49061000 -0.21061600 0.11067300 0.39802700 +Zr 0.09797000 0.15428000 6.38258000 0.19629500 -2.62187900 0.81038500 +Zr 1.42935000 1.64884000 4.78282000 5.08297200 -0.42159700 -0.67851500 +Zr 4.87702000 7.85368000 4.88279000 -2.33524600 2.82148000 -1.68347500 +Zr 0.05230000 3.19811000 6.40585000 -0.10496600 2.39935600 0.16638200 +Zr 8.19241000 4.81908000 4.73051000 -1.22868000 -1.08007200 0.97149900 +Zr 4.89754000 4.94110000 1.69964000 -2.24906500 -2.61724200 -0.33636600 +Zr 6.44153000 3.06623000 3.13056000 0.30210300 1.80915000 0.44235900 +Zr 0.04720000 6.37272000 3.11469000 1.26009500 0.82548700 -0.07187300 +Zr 1.55134000 1.72151000 1.48906000 4.53941500 -1.14875400 1.00802700 +Zr 6.33364000 6.39266000 6.39151000 2.11250700 -0.33827800 1.11585600 +Zr 4.80528000 7.88029000 1.66955000 -2.10128600 2.79408900 1.04210700 +Zr 8.04114000 1.70927000 4.91119000 -2.67768800 -0.32941700 1.54041300 +Zr 7.96040000 8.02735000 4.86003000 0.43958700 0.76381200 0.29915000 +Zr 6.53648000 6.25756000 3.28220000 1.16587500 0.36679100 -1.06470800 +W 6.25718000 9.54549000 6.42115000 1.32460200 -0.11885800 0.59104000 +W 3.08815000 3.30210000 3.20604000 -0.19788700 0.39967900 1.14255700 +W 7.98093000 4.72271000 1.48848000 -0.21792000 -0.91017100 0.18228000 +36 +Lattice="9.770912168 0.0 0.0 0.0 9.770912168 0.0 0.0 0.0 6.513941446" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=1.0247182617379567 fps_idx=-1 name=712 energy=-308.69694165 stress="0.12166491975314685 0.0068899332674621444 0.002342869039073673 0.0068899332674621444 0.1225653697917198 -0.010675776801348767 0.002342869039073673 -0.010675776801348767 0.12271392395042494" free_energy=-308.70381995 pbc="T T T" +Ti 6.51394000 0.00000000 3.25697000 0.09275800 0.84328500 
-0.66454200 +Ti 4.88546000 1.62849000 4.88546000 -0.52071500 0.36620700 -0.52906100 +Ti 1.62849000 1.62849000 1.62849000 0.66985700 -0.47792200 -0.32320600 +Ti 0.00000000 0.00000000 3.25697000 0.16179500 -0.40246200 0.04161600 +Ti 4.88546000 4.88546000 1.62849000 -1.01655000 0.69151000 1.14579700 +Ti 0.00000000 3.25697000 3.25697000 0.35292800 0.35056800 0.45860700 +Ti 3.25697000 0.00000000 3.25697000 -0.45219000 -0.50469500 -1.15525000 +Fe 8.14243000 8.14243000 4.88546000 -0.80646300 0.27182000 -0.39593000 +Fe 6.51394000 3.25697000 3.25697000 0.11288900 -0.22448500 0.65072700 +Fe 4.88546000 4.88546000 4.88546000 -0.37082800 0.25717000 -0.02599500 +Fe 4.88546000 8.14243000 1.62849000 0.36224300 -0.09630100 -0.37213600 +Fe 3.25697000 6.51394000 3.25697000 -0.51768100 -0.22416700 -0.17063000 +Fe 8.14243000 1.62849000 1.62849000 -0.22613300 -0.78889400 0.59572600 +Fe 0.00000000 6.51394000 3.25697000 0.26811400 -0.07800500 0.64069500 +Fe 1.62849000 4.88546000 1.62849000 1.31014300 -0.09394800 0.38258100 +Fe 1.62849000 4.88546000 4.88546000 0.92867200 -0.00729100 -0.49336400 +Fe 4.88546000 1.62849000 1.62849000 -0.65799200 -0.18760500 -0.06113800 +Fe 0.00000000 0.00000000 0.00000000 0.48076800 0.59355400 -0.80716200 +Fe 1.62849000 8.14243000 1.62849000 0.85765700 0.14208300 0.18653800 +Mo 8.14243000 4.88546000 4.88546000 -1.03385400 1.01722400 -2.24430400 +Mo 3.25697000 6.51394000 0.00000000 -1.16089800 -0.41476300 0.43885600 +Mo 6.51394000 6.51394000 0.00000000 -0.72181900 -0.33830700 -0.92411800 +Mo 3.25697000 3.25697000 3.25697000 -0.07263100 0.65746300 -0.05804900 +Mo 6.51394000 6.51394000 3.25697000 -0.48713700 -0.59295300 0.91162500 +Mo 1.62849000 8.14243000 4.88546000 0.54132200 -0.22979500 0.00456800 +Mo 3.25697000 3.25697000 0.00000000 -0.17924100 0.40017100 -0.06557500 +Mo 3.25697000 0.00000000 0.00000000 -0.61518100 -0.10949000 0.87985500 +Lu 4.88546000 8.14243000 4.88546000 -0.59112800 -1.31184500 -0.33813800 +Lu 0.00000000 3.25697000 0.00000000 0.96312300 0.20013900 0.14848500 +Lu 8.14243000 4.88546000 1.62849000 -0.76560300 0.58052700 2.36557600 +Lu 0.00000000 6.51394000 0.00000000 2.05860500 0.65531400 -1.07496700 +Lu 8.14243000 8.14243000 1.62849000 0.71240300 0.19100800 0.68231900 +Lu 6.51394000 3.25697000 0.00000000 -1.23626000 -0.65941500 0.36476200 +Ta 8.14243000 1.62849000 4.88546000 -0.02728300 -0.97718700 -1.09596200 +Ta 1.62849000 1.62849000 4.88546000 1.23017800 -0.48478000 -0.00933500 +Ta 6.51394000 0.00000000 0.00000000 0.35613300 0.98626600 0.91052900 +36 +Lattice="8.782773097 0.0 0.0 0.0 8.782773097 0.0 0.0 0.0 5.855182065" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=0.9496633517239711 fps_idx=-1 name=457 energy=-298.46974511 stress="-0.2768875068083879 -0.005414652721413622 0.0029666517026235716 -0.005414652721413622 -0.2674046134726871 -0.015330326058394074 0.0029666517026235716 -0.015330326058394074 -0.2529102376975591" free_energy=-298.46273679 pbc="T T T" +Co 5.85518000 5.85518000 2.92759000 0.63243700 -0.01359400 -0.34872300 +Co 1.46380000 4.39139000 4.39139000 -0.70372600 -0.50039500 -0.26274500 +Co 1.46380000 7.31898000 1.46380000 -0.32713300 0.16614900 0.01649300 +Ni 1.46380000 7.31898000 4.39139000 -0.93799400 -0.20898300 -0.38204500 +Ni 0.00000000 2.92759000 2.92759000 -0.01338100 0.05091700 0.60243400 +Ni 4.39139000 4.39139000 1.46380000 1.02426200 0.74788700 -0.39609100 +Ni 0.00000000 2.92759000 0.00000000 0.34718200 -0.26807600 -0.82662200 +Ni 5.85518000 0.00000000 0.00000000 0.97699000 0.02603000 -0.03345200 +Ni 7.31898000 1.46380000 
4.39139000 0.23885200 0.34220500 0.59843600 +Cu 0.00000000 5.85518000 2.92759000 1.02987700 0.77767300 -0.19765400 +Cu 7.31898000 7.31898000 1.46380000 -1.12528500 -0.73536500 0.11870100 +Cu 5.85518000 5.85518000 0.00000000 -0.08526800 -0.05343000 1.75210800 +Rh 0.00000000 0.00000000 2.92759000 0.47939100 -1.06462900 0.15976000 +Rh 2.92759000 0.00000000 2.92759000 -1.92185100 -1.35850200 -0.10284100 +Rh 1.46380000 1.46380000 1.46380000 -1.91930200 0.68628700 -0.08622700 +Rh 4.39139000 7.31898000 1.46380000 1.33896700 -0.56895600 -0.24677600 +Rh 5.85518000 2.92759000 0.00000000 0.50824900 -0.57313600 0.32699200 +Rh 1.46380000 1.46380000 4.39139000 -1.43135400 0.99008300 -0.25834800 +Hf 7.31898000 1.46380000 1.46380000 1.03622600 1.88803900 0.51361700 +Hf 7.31898000 7.31898000 4.39139000 -1.76365700 0.34708300 -2.63827200 +Hf 4.39139000 7.31898000 4.39139000 1.25548700 0.78055100 1.26839700 +Hf 0.00000000 5.85518000 0.00000000 3.47981700 0.39839700 1.69715900 +Hf 2.92759000 2.92759000 2.92759000 -2.10900700 0.94306400 1.44668100 +Hf 7.31898000 4.39139000 4.39139000 -0.31173700 -2.52985900 -1.43397400 +Hf 2.92759000 5.85518000 2.92759000 -2.05541800 0.18904700 -1.30321900 +Hf 0.00000000 0.00000000 0.00000000 3.10616500 0.48842800 -0.11358200 +Hf 4.39139000 1.46380000 1.46380000 0.55084000 -1.68864000 -1.80466000 +Ir 2.92759000 5.85518000 0.00000000 -0.90688800 0.06169700 -0.42066800 +Ir 5.85518000 0.00000000 2.92759000 1.70372800 -0.55762300 0.31761500 +Ir 7.31898000 4.39139000 1.46380000 0.80149100 1.07259100 0.28088700 +Ir 4.39139000 1.46380000 4.39139000 0.78974300 -0.41505500 0.56402000 +Ir 5.85518000 2.92759000 2.92759000 0.64432600 0.20875900 -0.42608500 +Ir 2.92759000 2.92759000 0.00000000 -0.29839300 1.51968000 0.29139400 +Ir 4.39139000 4.39139000 4.39139000 0.14495700 1.63451200 -0.40929800 +Ir 2.92759000 0.00000000 0.00000000 -1.39147600 -1.99023200 0.29028900 +Pt 1.46380000 4.39139000 1.46380000 -2.78711800 -0.79260600 1.44629900 +48 +Lattice="7.699206955 0.0 0.0 0.0 7.699206955 0.0 0.0 0.0 11.548810433" Properties=species:S:1:pos:R:3:forces:R:3 name=132 energy=-492.3406702 stress="0.20076053153203025 0.0003938267428249818 0.0006968020788136071 0.0003938267428249818 0.204728052913553 -0.00046455552423949086 0.0006968020788136071 -0.00046455552423949086 0.19849039479880043" free_energy=-492.34422801 pbc="T T T" +Mn 0.00000000 1.92480000 5.77441000 0.37615400 -0.35167600 -0.00738200 +Mn 1.92480000 0.00000000 1.92480000 0.01014900 0.70746300 -0.22709100 +Mn 5.77441000 1.92480000 3.84960000 -0.00675500 -0.63748000 0.10510300 +Mn 5.77441000 3.84960000 1.92480000 0.02357700 -0.68199800 -0.78033100 +Mn 3.84960000 0.00000000 7.69921000 -0.89206400 0.42856000 -0.07030900 +Mn 3.84960000 5.77441000 1.92480000 0.34186600 0.68345900 -0.24820200 +Mn 3.84960000 3.84960000 7.69921000 -0.40315100 -0.38630300 0.03747000 +Mn 1.92480000 3.84960000 5.77441000 0.03795800 -1.09537100 1.09587000 +Mn 0.00000000 0.00000000 3.84960000 -0.01906200 1.35800500 -0.02467400 +Mn 0.00000000 3.84960000 7.69921000 0.46912100 -0.40399500 0.04118600 +Mn 0.00000000 1.92480000 9.62401000 0.26916200 0.01899600 0.24555100 +Mn 1.92480000 0.00000000 5.77441000 0.05510400 1.01565300 0.37939400 +Mn 5.77441000 5.77441000 0.00000000 0.37978500 -0.12199200 -0.04888900 +Mn 1.92480000 5.77441000 7.69921000 -0.01243100 0.26035800 0.08508700 +Mn 0.00000000 3.84960000 0.00000000 -0.51066300 -0.29069400 -0.21307400 +Mn 1.92480000 1.92480000 3.84960000 0.00376900 -0.86223900 0.39197200 +Mn 5.77441000 0.00000000 1.92480000 0.11532800 
0.66502600 -0.21637500 +Mn 3.84960000 1.92480000 5.77441000 -0.34168900 -0.41799900 0.11146400 +Mn 0.00000000 0.00000000 7.69921000 0.88754200 0.42212300 -0.01732900 +Mn 0.00000000 5.77441000 9.62401000 0.29511800 -0.00572800 0.26545100 +Mn 5.77441000 3.84960000 9.62401000 0.27213400 -0.07834000 0.40286100 +Mn 3.84960000 5.77441000 9.62401000 -0.34327900 0.00096700 -0.35648600 +Mn 3.84960000 0.00000000 3.84960000 0.05738100 0.99915400 -0.36335800 +Mn 0.00000000 1.92480000 1.92480000 -0.28125700 -0.66237900 -0.28226400 +Mn 3.84960000 1.92480000 1.92480000 0.25278100 -0.58963100 0.22676400 +Mn 5.77441000 1.92480000 0.00000000 0.56795700 0.07757200 0.19660600 +Mn 1.92480000 0.00000000 9.62401000 -0.26151500 -0.01347200 -0.21549200 +Mn 5.77441000 0.00000000 5.77441000 -0.04281200 1.01459500 -0.21414300 +Fe 1.92480000 1.92480000 0.00000000 -0.42701000 -0.35805200 0.03457600 +Fe 3.84960000 1.92480000 9.62401000 -0.25947000 -0.00228400 -0.18020400 +Fe 5.77441000 3.84960000 5.77441000 -0.05192400 -0.85754200 0.34840600 +Fe 1.92480000 5.77441000 0.00000000 -0.30525500 0.44245700 -0.22143700 +Fe 0.00000000 0.00000000 0.00000000 0.16461300 0.17861500 0.06608600 +Fe 1.92480000 1.92480000 7.69921000 0.01891100 -0.17809700 -0.17872800 +W 0.00000000 5.77441000 5.77441000 0.71273000 0.67376700 1.56388400 +W 3.84960000 0.00000000 0.00000000 -0.86865900 0.12589500 0.69110300 +W 3.84960000 3.84960000 3.84960000 0.83454100 -2.37801900 0.02897400 +W 5.77441000 1.92480000 7.69921000 -0.11415500 0.75168300 -0.77724300 +W 3.84960000 3.84960000 0.00000000 1.13608300 -0.10832000 -0.02439400 +W 5.77441000 5.77441000 7.69921000 0.03687900 -0.86703800 0.36429200 +W 3.84960000 5.77441000 5.77441000 -0.70956700 0.70792600 1.38956200 +W 0.00000000 5.77441000 1.92480000 -0.64020700 1.70371200 -2.36013800 +W 5.77441000 5.77441000 3.84960000 -0.81184100 1.53481200 -0.57412500 +W 1.92480000 3.84960000 1.92480000 -0.09772300 -1.55518600 -1.92848400 +W 0.00000000 3.84960000 3.84960000 -0.80107200 -3.23575500 0.85786700 +W 5.77441000 0.00000000 9.62401000 0.77915200 0.10037100 0.93724800 +W 1.92480000 3.84960000 9.62401000 -0.67480300 -0.15918500 -0.50703200 +W 1.92480000 5.77441000 3.84960000 0.77857200 2.42761200 0.17041100 +36 +Lattice="10.17896035 0.0 0.0 0.0 10.17896035 0.0 0.0 0.0 6.785973567" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=1.0936659901056818 fps_idx=-1 name=654 energy=-295.5983349 stress="0.23889555934780998 -0.00207471508249835 0.006845643518704877 -0.00207471508249835 0.2384006576062004 0.0015937943213307935 0.006845643518704877 0.0015937943213307935 0.23851844112491494" free_energy=-295.63442729 pbc="T T T" +Ru 8.48247000 1.69649000 5.08948000 -0.12982700 0.36181500 0.24203200 +Ru 3.39299000 6.78597000 3.39299000 -0.38140600 0.11691800 -0.11230400 +Ru 8.48247000 8.48247000 5.08948000 -0.13502500 0.17453200 0.10926100 +Ru 1.69649000 1.69649000 5.08948000 0.38167900 0.10245300 -0.17445700 +Ru 1.69649000 8.48247000 5.08948000 0.57463300 0.09994000 0.02340500 +Ru 8.48247000 5.08948000 5.08948000 -0.04747600 -0.50482100 -0.10396300 +Ru 5.08948000 5.08948000 5.08948000 -0.12825700 -0.18877800 0.12089700 +Ru 0.00000000 3.39299000 0.00000000 -0.17462300 -0.43186100 -0.22794400 +Ru 1.69649000 5.08948000 5.08948000 0.14677400 -0.20374700 -0.33646500 +Ru 1.69649000 1.69649000 1.69649000 -0.01867200 0.09525000 0.29862700 +Ru 5.08948000 1.69649000 1.69649000 0.17631100 -0.02017200 0.17527900 +Ru 0.00000000 6.78597000 3.39299000 -0.03016500 -0.46664600 0.00804200 +Ru 6.78597000 0.00000000 3.39299000 
-0.32422200 0.05867500 0.06183000 +Ag 6.78597000 6.78597000 3.39299000 -0.25876600 -0.06087900 0.32704800 +Ag 0.00000000 0.00000000 3.39299000 0.16669900 0.14996000 0.13023000 +Ag 0.00000000 6.78597000 0.00000000 0.44722300 -0.04459100 -0.21752100 +Ag 3.39299000 3.39299000 0.00000000 -0.19921000 0.16848600 0.19999100 +Hf 8.48247000 5.08948000 1.69649000 -0.12812600 -1.30571600 0.08735800 +Hf 3.39299000 0.00000000 0.00000000 -1.00006500 -1.79502600 0.98678900 +Hf 5.08948000 1.69649000 5.08948000 1.45855000 0.79074400 -1.38070300 +Hf 8.48247000 8.48247000 1.69649000 -0.34849900 1.03554700 -0.50399400 +Ir 6.78597000 6.78597000 0.00000000 0.07768500 0.16412800 0.18465300 +Ir 3.39299000 6.78597000 0.00000000 0.28476700 0.14304000 0.04170600 +Ir 8.48247000 1.69649000 1.69649000 -0.32584700 0.18326200 -0.15463400 +Ir 5.08948000 5.08948000 1.69649000 -0.03131100 -0.16695400 -0.07494500 +Ir 6.78597000 3.39299000 3.39299000 0.06062700 -0.13454200 0.13316800 +Ir 6.78597000 0.00000000 0.00000000 0.07864400 0.17272300 -0.15694000 +Ir 0.00000000 0.00000000 0.00000000 -0.42713800 0.36124800 -0.14929400 +Ir 1.69649000 8.48247000 1.69649000 0.37986100 0.01566900 -0.17221400 +Ir 1.69649000 5.08948000 1.69649000 -0.07261900 0.07827600 0.17032800 +Ir 3.39299000 3.39299000 3.39299000 -0.08279200 0.14605800 0.16004900 +Ir 5.08948000 8.48247000 1.69649000 -0.21679000 0.27509300 -0.05191800 +Ir 6.78597000 3.39299000 0.00000000 0.33245300 -0.14503600 -0.11472300 +Ir 0.00000000 3.39299000 3.39299000 -0.08648300 0.51844200 -0.08246900 +Ir 3.39299000 0.00000000 3.39299000 0.35128200 -0.24542700 0.38977100 +Ir 5.08948000 8.48247000 5.08948000 -0.36987100 0.50193500 0.16402600 +36 +Lattice="8.470538482 0.0 0.0 0.0 8.470538482 0.0 0.0 0.0 5.647025655" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=0.9140462950723837 fps_idx=-1 name=4599 energy=-232.43740238 stress="-0.5551028322682163 0.010617643385350291 -0.010828831088133675 0.010617643385350291 -0.553487985336604 0.009739993579614217 -0.010828831088133675 0.009739993579614217 -0.5563721180470742" free_energy=-232.44483553 pbc="T T T" +Sc 7.06356000 4.20968000 1.36018000 -0.71162000 -1.09016300 2.76741400 +V 0.01672000 8.42286000 2.84840000 -0.09835200 0.30749700 -1.52562900 +V 5.63776000 5.81834000 2.96876000 0.85988500 -5.65072700 -2.17141700 +V 8.44912000 2.70401000 2.82385000 1.36272600 2.96979900 1.29613600 +V 7.09000000 4.17781000 4.21865000 0.96721700 0.46987100 -1.77875000 +Fe 0.08512000 2.91440000 5.59498000 -2.11761200 -3.25009100 0.34717100 +Fe 8.43908000 5.56381000 2.84604000 -0.63359400 -0.96347200 0.92870800 +Fe 4.35253000 4.25208000 1.54113000 -1.23466800 1.91980300 -0.22488900 +Fe 2.84223000 5.58915000 2.84549000 -0.00008000 0.80950600 -0.96728700 +Fe 2.81595000 8.39327000 2.81460000 -0.31152100 0.68495800 0.94247300 +Zn 4.27857000 1.42065000 4.26019000 -1.52016300 -0.18062800 -0.25500100 +Zn 2.98135000 0.01215000 5.57873000 -2.80748500 -1.99925900 -0.90651300 +Zn 2.78437000 5.63597000 0.12477000 1.55460000 1.11119400 -0.50873900 +Zn 1.38303000 1.38437000 4.24612000 0.50826200 -0.70991600 1.12463600 +Zn 1.37913000 7.13745000 4.05508000 1.33349500 -1.79565400 2.21312200 +Zn 8.39640000 0.15374000 0.04578000 1.16709300 -1.46426600 -1.80520900 +Rh 5.71683000 2.86513000 2.88028000 -3.59406400 2.94564700 -1.09423700 +Rh 4.36513000 7.05475000 4.11228000 -8.03178500 -2.88356200 2.13389500 +Rh 7.11786000 7.21933000 4.12523000 3.44166300 -3.13284900 3.22760400 +Rh 7.04075000 1.49046000 4.11244000 5.49776600 -2.20632800 3.57446100 +Rh 5.60416000 
0.03868000 2.96491000 0.21923900 4.04123400 -4.95542500 +Rh 4.31374000 7.07469000 1.50923000 -4.07193500 -1.42833600 -0.93999200 +Rh 5.68337000 2.79417000 5.60472000 0.27032700 1.43269600 1.27176200 +Pd 2.87215000 2.82752000 0.00653000 -0.66375500 -1.69798100 0.19426000 +Pd 1.24012000 6.94337000 1.34469000 3.37882000 2.82304000 0.68041400 +Pd 4.23260000 4.23165000 4.39253000 -1.48400300 2.67374900 -3.32611100 +Pd 5.58330000 0.01089000 5.54795000 2.73698500 0.60196700 1.40577200 +Pd 1.34732000 4.21178000 1.41646000 2.23652700 -0.42942200 3.00877900 +Pd 4.18748000 1.33186000 1.29788000 0.52781000 -0.46777600 3.51233100 +Hf 6.98023000 6.92740000 1.39853000 2.42551700 5.21544500 2.35890800 +Hf 6.99595000 1.50886000 1.45391000 0.34210500 -2.77586800 -2.48987000 +Hf 8.43312000 5.45272000 0.03649000 0.20167500 3.87916900 -4.45817700 +Hf 5.57737000 5.67304000 5.60593000 -0.98899200 -1.42553300 0.79527200 +Hf 2.92739000 2.77952000 2.75245000 0.67670700 2.00634900 0.60049400 +Pt 1.42862000 1.36752000 1.45982000 -1.50382300 -1.85238900 -0.84318000 +Pt 1.46312000 4.21481000 4.30096000 0.06503600 1.51229500 -4.13318500 +48 +Lattice="8.182103461 0.0 0.0 0.0 8.182103461 0.0 0.0 0.0 12.273155192" Properties=species:S:1:pos:R:3:forces:R:3 class=3 scale=0.9498863707029109 fps_idx=-1 name=2510 energy=-352.70544158 stress="-0.11205496801617043 -0.003212348709363966 0.007141647088999891 -0.003212348709363966 -0.12916693285454786 -0.011578623579426031 0.007141647088999891 -0.011578623579426031 -0.12328456640527706" free_energy=-352.70655496 pbc="T T T" +Ti 6.07931000 6.15929000 4.03516000 0.86877500 -1.56169600 -0.81191300 +Ti 0.00293000 8.11661000 8.18260000 -1.33223800 -0.88222000 1.77454300 +Ti 8.17795000 3.97600000 8.22427000 -0.53979700 0.84738500 -1.10285100 +Ti 0.06540000 4.15437000 4.14571000 -0.90895500 -1.25799200 -1.40517100 +Ti 3.99524000 4.10160000 4.10883000 0.97886800 -1.80310400 0.12705100 +Ti 4.17200000 2.00901000 2.11964000 1.11428500 1.52070300 1.31999000 +Fe 6.16487000 4.08145000 2.02554000 0.78506700 0.00049900 0.64427700 +Fe 6.26746000 0.01908000 10.32914000 -0.10338800 0.10852600 -0.83000900 +Fe 0.01862000 4.13851000 12.18130000 -0.56093100 -0.53258400 0.82349200 +Fe 2.07800000 2.11482000 4.18724000 0.51563900 0.37520800 0.19093300 +Fe 1.94952000 4.20131000 5.90533000 0.12927500 -0.68762800 0.34013600 +Fe 2.02241000 6.14658000 8.25776000 -0.42131700 -0.28048200 -0.65316500 +Fe 6.10795000 2.10450000 8.20943000 0.43636800 -0.16403800 -0.00424100 +Fe 6.08072000 6.08797000 8.14604000 -0.06797500 0.11946000 0.61889500 +Fe 6.25459000 1.99802000 0.04060000 0.04963300 0.55046100 -0.10681700 +Fe 6.16302000 2.07005000 4.08680000 -0.60655100 0.53305800 -0.31721600 +Fe 0.01753000 6.06919000 2.03629000 -0.60322400 -0.65250700 0.14070700 +Y 4.12581000 0.10505000 0.09457000 0.80491400 -0.22312000 -0.38651500 +Y 1.97156000 0.05513000 2.18999000 0.18540500 -0.00286300 0.06330500 +Y 6.13081000 6.17762000 12.19179000 1.11087400 -0.37505900 0.72909900 +Y 1.98358000 6.03523000 4.16697000 -0.58503100 0.15197000 -0.52383800 +Y 0.00181000 1.97054000 6.08471000 -0.16822600 2.06123900 -0.47612300 +Y 8.11645000 6.07315000 6.21688000 0.75253700 -0.62994800 1.68300500 +Y 4.21643000 4.04821000 0.10661000 1.79894400 -1.16574700 0.20136000 +Y 2.05583000 0.05831000 10.25821000 -0.96296700 -0.95357500 -1.11133600 +Y 6.21120000 8.15295000 6.17183000 -1.71006100 0.54362400 0.62271600 +Y 8.18034000 6.01042000 10.24525000 -0.82788100 0.54755700 -1.75278900 +Y 4.08030000 4.04863000 8.25740000 0.23046200 0.36816300 -2.14277800 
+Y 0.14204000 0.03016000 0.00589000 -1.68623800 -0.63646700 -0.31242200 +Y 4.11090000 1.97359000 10.33084000 2.67378300 -0.54574300 -0.75713700 +Y 6.07663000 8.12095000 2.03710000 1.73136400 0.59029600 0.62307300 +Y 1.98000000 1.94504000 8.14578000 -0.37478700 1.88758200 -0.01435500 +Y 4.05783000 6.05414000 2.07640000 1.11750500 -0.80977800 0.40469400 +Y 2.10470000 4.13135000 1.92365000 -2.93525800 -1.29050100 1.42612600 +Y 8.07673000 1.89777000 2.04405000 -1.30710000 4.02333600 0.74248800 +Zr 6.16898000 4.19227000 6.07591000 -0.26121400 -1.33786700 0.16652700 +Zr 4.05257000 6.08972000 6.08519000 -0.16383100 -0.66216100 0.95210400 +Zr 4.06046000 2.10598000 6.07524000 0.20631100 0.59743600 -1.60142200 +Zr 2.14282000 4.23976000 10.19569000 -1.67981400 -1.10494700 -0.14605900 +Zr 6.06226000 4.06691000 10.33828000 1.80157600 -1.30937400 -0.91261900 +Zr 4.00108000 8.16450000 8.17074000 2.86559000 -0.83650700 0.47508000 +Zr 2.12539000 6.09336000 12.23264000 -1.16917500 0.23664600 0.64279600 +Zr 4.09772000 8.10282000 3.93642000 0.20825100 1.96857300 2.32380800 +Zr 2.14156000 2.10078000 0.02299000 -1.43235800 1.14352300 -0.20156600 +Zr 4.05712000 6.16297000 10.10970000 1.53670700 1.89256300 -0.53250200 +Zr 8.04705000 2.07642000 10.17236000 -1.40950600 0.17752400 0.86417600 +Zr 0.04437000 8.17222000 4.16041000 -0.41756300 -0.22497900 -1.08483300 +Zr 2.05366000 8.10143000 6.27713000 0.33325000 -0.31444200 -0.71270200 +36 +Lattice="9.883640051 0.0 0.0 0.0 9.883640051 0.0 0.0 0.0 6.589093367" Properties=species:S:1:pos:R:3:forces:R:3 name=527 energy=-171.17145311 stress="0.13263293649478775 0.0015481251990567058 -0.001552544187517831 0.0015481251990567058 0.13087043418178432 -0.0017256149940694482 -0.001552544187517831 -0.0017256149940694482 0.1296045812329457" free_energy=-171.2460493 pbc="T T T" +Ti 4.89143000 8.23478000 5.05098000 0.57862800 0.42604200 0.34052700 +Ti 9.87294000 0.03866000 0.03729000 0.43869700 0.36324800 0.59601000 +Ti 4.85292000 1.64484000 4.97127000 1.13751700 0.13176000 0.13112400 +Co 5.07027000 4.94643000 5.12986000 -0.16756400 -0.59735600 -0.88502600 +Co 8.14137000 1.52126000 1.76087000 0.81293800 0.04966400 0.07060300 +Co 9.84968000 6.74687000 3.37568000 0.52191800 0.37189500 -0.41762100 +Co 8.17908000 1.48837000 4.84918000 0.55991000 0.19618300 0.23507900 +Co 1.75338000 1.53479000 5.00266000 -1.03709800 0.21075500 -0.01777500 +Co 6.61490000 9.85740000 6.51423000 -0.31378800 0.36901600 -0.49094000 +Co 9.85835000 3.21962000 0.00376000 0.09747500 -0.88558100 0.00827300 +Co 0.06086000 6.54001000 0.03242000 0.10241300 1.07896300 0.19957300 +Co 1.56560000 8.38701000 1.67967000 -1.32177700 0.13335200 0.01444700 +Co 6.54694000 3.34589000 3.18848000 -0.39547400 -0.84240900 0.39284300 +Co 8.20098000 8.23716000 1.52705000 0.94468700 0.31992800 -0.34519200 +Co 9.88332000 3.48283000 3.33143000 -0.26977300 -1.06895500 -0.15108700 +Co 0.07841000 0.19616000 3.10368000 0.21241700 0.21671900 -0.31093700 +Co 4.85281000 1.66007000 1.56779000 0.40990000 -0.07217200 0.10313100 +Co 1.72314000 1.55411000 1.62807000 -1.10408100 -0.27142400 0.00568200 +Co 1.56151000 8.23161000 4.95778000 -0.70074600 -0.08723900 0.03510500 +Ag 3.40556000 6.56819000 3.24242000 -0.02857600 0.58145100 1.01078900 +Ag 8.21158000 4.91596000 4.93151000 0.56608500 -0.48261500 0.03282700 +Ag 6.70010000 6.50587000 3.31648000 -0.66927000 0.54460200 0.43538200 +Ag 6.66474000 3.09976000 0.04300000 0.09615600 -0.21960600 0.02482200 +Ag 3.20348000 3.16919000 6.57527000 -0.13323000 0.28096100 0.08876500 +Ag 1.85178000 4.99397000 
1.89072000 -1.70297600 -0.22670900 -0.85133600 +Ag 6.69395000 0.06959000 3.29218000 -0.47879000 -0.20103000 -0.13586100 +Ag 8.34931000 8.28672000 4.89456000 0.07447700 0.05393400 0.26185400 +Ag 4.99244000 8.21976000 1.73138000 0.22351000 0.50653600 -0.35627500 +Ag 3.20156000 0.07012000 3.40094000 -0.13686800 -0.30965700 -0.42992700 +Ag 4.91896000 4.96943000 1.54659000 0.44820000 -1.06410400 0.52220200 +Ag 3.35163000 6.45008000 0.04342000 -0.74554400 0.78062500 -0.65683400 +Ag 6.52382000 6.43099000 6.51782000 0.29525100 1.27218200 -0.15934500 +Ag 1.49305000 4.99444000 4.93981000 -0.17346600 0.03636500 0.03964300 +Ag 8.04772000 5.07815000 1.60302000 1.49569600 -0.80272000 0.04549400 +Ag 3.18478000 0.00575000 6.53131000 0.00262300 -0.28013000 0.42936900 +Ag 3.28501000 3.32146000 3.36929000 0.36052400 -0.51247600 0.18461300 +48 +Lattice="7.17941738 0.0 0.0 0.0 7.17941738 0.0 0.0 0.0 10.76912607" Properties=species:S:1:pos:R:3:forces:R:3 name=712 energy=-297.34150755 stress="-0.19552168339816534 -0.020412056683033083 -0.008598065794469993 -0.020412056683033083 -0.20278676257163872 -0.007518478202475092 -0.008598065794469993 -0.007518478202475092 -0.19521058785880394" free_energy=-297.34652158 pbc="T T T" +V 7.11718000 7.11129000 7.18645000 1.79158300 0.33415000 1.92658200 +V 1.84239000 1.74629000 3.62451000 -1.57853500 1.32765800 0.59399900 +Co 7.00040000 1.72838000 1.66946000 2.39490300 0.11479600 2.90189300 +Ni 5.30558000 7.15008000 8.94637000 0.80961600 -0.05160300 -0.96203300 +Ni 5.32352000 3.59696000 8.98374000 -0.39508900 0.50432400 -0.45463700 +Ni 3.57677000 1.77579000 8.95092000 -0.52971800 0.53371300 -0.76224800 +Ni 3.61419000 3.62694000 7.19572000 -0.45087100 -0.58329600 0.43082400 +Ni 7.03560000 7.14007000 3.72586000 0.97577400 0.52731200 -1.42999400 +Ni 3.62360000 5.36291000 8.86829000 0.38819100 0.52322800 0.72481000 +Ni 1.80336000 3.75544000 8.97631000 1.05306300 -1.75691700 0.18744200 +Ni 0.14380000 5.39523000 8.91598000 -2.34231200 0.30886000 0.52227100 +Ni 3.57701000 5.29785000 5.44087000 0.22962200 0.49758200 0.19986700 +Ni 1.81301000 7.10648000 8.96311000 0.06175500 1.32211000 -0.20679200 +Ni 5.40925000 1.85474000 3.66549000 0.49396800 -0.75501800 -0.77806000 +Ni 0.00978000 3.52374000 3.68158000 -0.21729000 0.49814500 -1.65871100 +Ni 3.52386000 0.10375000 7.08630000 0.49194400 -0.76310600 1.05212200 +Zn 5.35152000 5.38908000 0.07003000 0.65874800 0.08725900 -3.03859000 +Zn 1.71810000 1.78022000 10.76614000 1.45147700 0.15531600 -0.03485700 +Zn 5.43885000 5.36625000 7.14949000 -1.53760900 0.09137600 1.34935700 +Zn 5.40297000 0.04311000 5.41880000 -1.34142400 -0.91455200 -0.38931700 +Zn 5.38307000 1.88036000 7.08385000 -1.95693800 -2.01589500 3.46932300 +Zn 0.08957000 0.04350000 0.01047000 -0.46440400 -0.70785500 -1.08040500 +Zn 1.72689000 1.83419000 7.04878000 3.14855900 -1.93362600 4.37125200 +Zn 1.83785000 0.07951000 2.00092000 -2.26467200 -0.39580300 -2.50355000 +Zn 1.74631000 3.64878000 1.82096000 -1.12646900 -1.39356600 0.39805900 +Zn 0.03932000 1.84746000 8.99739000 -0.03015200 -1.19657000 -0.63139300 +Zn 7.12077000 5.36361000 5.38609000 0.66458300 -0.46866600 -0.14865200 +Zn 3.64047000 1.74866000 5.43851000 -0.63977700 0.53187900 -0.17942500 +Zn 1.77344000 7.07163000 5.35802000 0.54077300 0.27867600 0.69147200 +Ru 7.15192000 3.49530000 7.23537000 0.57050400 4.24848300 1.90585800 +Ru 3.51936000 5.44304000 1.76501000 0.41417900 0.01168600 1.18428900 +Ru 5.37670000 5.33852000 3.56367000 0.02034200 0.46416400 1.02338100 +Ru 3.58858000 3.61462000 0.12729000 -2.76913000 
0.45009300 -4.69317100 +Ru 3.62656000 0.04205000 10.72898000 -2.27418300 -0.89703100 -1.11254400 +Ru 3.65283000 0.07619000 3.54955000 0.20058700 -1.08409600 2.16137000 +Ru 1.85735000 5.39837000 3.64583000 -1.19351700 -1.49949000 -0.21424300 +Ru 5.34272000 3.54244000 5.44358000 -0.07257600 1.74005600 -1.58712900 +Ru 3.58709000 1.90088000 1.89881000 -0.88839900 -2.63735200 -1.45169500 +Ru 0.01587000 5.39409000 1.77150000 0.82901100 1.60199700 1.76215500 +Ru 1.85439000 5.42581000 7.30256000 -0.90550100 -0.11805500 -2.15382400 +Ru 1.68721000 3.43545000 5.38194000 3.99508700 5.03680400 -1.33306000 +Ru 1.82443000 5.40731000 10.73790000 0.27205500 0.28496700 -0.63418800 +Ru 5.31221000 0.01647000 1.67253000 1.42779700 -0.19844200 3.09981000 +Ru 3.58128000 3.55006000 3.55187000 -1.21577400 1.54090900 3.07020300 +Pt 5.31998000 3.59215000 1.78596000 3.88356500 0.53526800 2.02880600 +Pt 7.17476000 3.57228000 10.75546000 1.10117300 1.11831900 -1.20907700 +Pt 0.03858000 1.80205000 5.48368000 -2.88233000 -4.56367800 -3.73236800 +Pt 5.37761000 1.80432000 10.72731000 -0.79219000 -0.73450900 -2.67517700 diff --git a/tests/resources/bpnn-model.pt b/tests/resources/bpnn-model.pt index 2996c7e88..95a2d3aa9 100644 Binary files a/tests/resources/bpnn-model.pt and b/tests/resources/bpnn-model.pt differ diff --git a/tests/utils/test_compute_loss.py b/tests/utils/test_compute_loss.py new file mode 100644 index 000000000..b4e4c29fb --- /dev/null +++ b/tests/utils/test_compute_loss.py @@ -0,0 +1,104 @@ +from pathlib import Path + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ModelCapabilities, ModelOutput + +from metatensor.models import soap_bpnn +from metatensor.models.utils.compute_loss import compute_model_loss +from metatensor.models.utils.data import read_structures +from metatensor.models.utils.loss import TensorMapDictLoss + + +RESOURCES_PATH = Path(__file__).parent.resolve() / ".." 
/ "resources" + + +def test_compute_model_loss(): + """Test that the model loss is computed.""" + + loss_fn = TensorMapDictLoss( + weights={ + "energy": {"values": 1.0, "positions": 10.0}, + } + ) + + capabilities = ModelCapabilities( + length_unit="Angstrom", + species=[21, 23, 24, 27, 29, 39, 40, 41, 72, 74, 78], + outputs={ + "energy": ModelOutput( + quantity="energy", + unit="eV", + ) + }, + ) + + model = soap_bpnn.Model(capabilities) + # model = torch.jit.script(model) # jit the model for good measure + + structures = read_structures(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:2] + + gradient_samples = Labels( + names=["sample", "atom"], + values=torch.stack( + [ + torch.concatenate( + [ + torch.tensor([i] * len(structure)) + for i, structure in enumerate(structures) + ] + ), + torch.concatenate( + [torch.arange(len(structure)) for structure in structures] + ), + ], + dim=1, + ), + ) + + gradient_components = [ + Labels( + names=["coordinate"], + values=torch.tensor([[0], [1], [2]]), + ) + ] + + block = TensorBlock( + values=torch.tensor([[0.0] * len(structures)]).T, + samples=Labels.range("structure", len(structures)), + components=[], + properties=Labels.single(), + ) + + block.add_gradient( + "positions", + TensorBlock( + values=torch.tensor( + [ + [[1.0], [1.0], [1.0]] + for structure in structures + for _ in range(len(structure.positions)) + ] + ), + samples=gradient_samples, + components=gradient_components, + properties=Labels.single(), + ), + ) + + targets = { + "energy": TensorMap( + keys=Labels( + names=["lambda", "sigma"], + values=torch.tensor([[0, 1]]), + ), + blocks=[block], + ), + } + + compute_model_loss( + loss_fn, + model, + structures, + targets, + ) diff --git a/tests/utils/test_loss.py b/tests/utils/test_loss.py new file mode 100644 index 000000000..e0d11f21f --- /dev/null +++ b/tests/utils/test_loss.py @@ -0,0 +1,240 @@ +from pathlib import Path + +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatensor.models.utils.loss import TensorMapDictLoss, TensorMapLoss + + +RESOURCES_PATH = Path(__file__).parent.resolve() / ".." 
/ "resources" + + +@pytest.fixture +def tensor_map_with_grad_1(): + block = TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.fixture +def tensor_map_with_grad_2(): + block = TensorBlock( + values=torch.tensor([[1.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.fixture +def tensor_map_with_grad_3(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [3.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +@pytest.fixture +def tensor_map_with_grad_4(): + block = TensorBlock( + values=torch.tensor([[0.0], [1.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + block.add_gradient( + "gradient", + TensorBlock( + values=torch.tensor([[1.0], [0.0], [2.0]]), + samples=Labels.range("sample", 3), + components=[], + properties=Labels.single(), + ), + ) + tensor_map = TensorMap(keys=Labels.single(), blocks=[block]) + return tensor_map + + +def test_tmap_loss_no_gradients(): + """Test that the loss is computed correctly when there are no gradients.""" + loss = TensorMapLoss() + + tensor_map_1 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + ], + ) + tensor_map_2 = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.tensor([[0.0], [2.0], [3.0]]), + samples=Labels.range("samples", 3), + components=[], + properties=Labels.single(), + ) + ], + ) + + assert torch.allclose(loss(tensor_map_1, tensor_map_1), torch.tensor(0.0)) + + # Expected result: 1.0/3.0 (there are three values) + assert torch.allclose(loss(tensor_map_1, tensor_map_2), torch.tensor(1.0 / 3.0)) + + +def test_tmap_loss_with_gradients(tensor_map_with_grad_1, tensor_map_with_grad_2): + """Test that the loss is computed correctly when there are gradients.""" + loss = TensorMapLoss(gradient_weights={"gradient": 0.5}) + + assert torch.allclose( + loss(tensor_map_with_grad_1, tensor_map_with_grad_1), torch.tensor(0.0) + ) + + # Expected result: 1.0/3.0 + 0.5 * 4.0 / 3.0 (there are three values) + assert torch.allclose( + loss(tensor_map_with_grad_1, tensor_map_with_grad_2), + torch.tensor(1.0 / 3.0 + 0.5 * 4.0 / 3.0), + ) + + +def test_tmap_dict_loss( + tensor_map_with_grad_1, + tensor_map_with_grad_2, + tensor_map_with_grad_3, + tensor_map_with_grad_4, +): + """Test that the dict loss is computed correctly.""" + + loss = TensorMapDictLoss( + weights={ + "output_1": {"values": 1.0, 
"gradient": 0.5}, + "output_2": {"values": 1.0, "gradient": 0.5}, + } + ) + + output_dict = { + "output_1": tensor_map_with_grad_1, + "output_2": tensor_map_with_grad_2, + } + + target_dict = { + "output_1": tensor_map_with_grad_3, + "output_2": tensor_map_with_grad_4, + } + + expected_result = ( + 1.0 + * ( + tensor_map_with_grad_1.block().values + - tensor_map_with_grad_3.block().values + ) + .pow(2) + .mean() + + 0.5 + * ( + tensor_map_with_grad_1.block().gradient("gradient").values + - tensor_map_with_grad_3.block().gradient("gradient").values + ) + .pow(2) + .mean() + + 1.0 + * ( + tensor_map_with_grad_2.block().values + - tensor_map_with_grad_4.block().values + ) + .pow(2) + .mean() + + 0.5 + * ( + tensor_map_with_grad_2.block().gradient("gradient").values + - tensor_map_with_grad_4.block().gradient("gradient").values + ) + .pow(2) + .mean() + ) + + assert torch.allclose(loss(output_dict, target_dict), expected_result) + + +def test_tmap_dict_loss_subset(tensor_map_with_grad_1, tensor_map_with_grad_3): + """Test that the dict loss is computed correctly when only a subset + of the possible targets is present both in outputs and targets.""" + + loss = TensorMapDictLoss( + weights={ + "output_1": {"values": 1.0, "gradient": 0.5}, + "output_2": {"values": 1.0, "gradient": 0.5}, + } + ) + + output_dict = { + "output_1": tensor_map_with_grad_1, + } + + target_dict = { + "output_1": tensor_map_with_grad_3, + } + + expected_result = ( + 1.0 + * ( + tensor_map_with_grad_1.block().values + - tensor_map_with_grad_3.block().values + ) + .pow(2) + .mean() + + 0.5 + * ( + tensor_map_with_grad_1.block().gradient("gradient").values + - tensor_map_with_grad_3.block().gradient("gradient").values + ) + .pow(2) + .mean() + ) + + assert torch.allclose(loss(output_dict, target_dict), expected_result) diff --git a/tests/utils/test_model_io.py b/tests/utils/test_model_io.py index 75ba357b6..36e738ba4 100644 --- a/tests/utils/test_model_io.py +++ b/tests/utils/test_model_io.py @@ -30,12 +30,14 @@ def test_save_load_model(monkeypatch, tmp_path): model = soap_bpnn.Model(capabilities) structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz") - output_before_save = model(rascaline.torch.systems_to_torch(structures)) + output_before_save = model(rascaline.torch.systems_to_torch(structures), ["energy"]) save_model(model, "test_model.pt") loaded_model = load_model("test_model.pt") - output_after_load = loaded_model(rascaline.torch.systems_to_torch(structures)) + output_after_load = loaded_model( + rascaline.torch.systems_to_torch(structures), ["energy"] + ) assert metatensor.torch.allclose( output_before_save["energy"], output_after_load["energy"] diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py new file mode 100644 index 000000000..528460062 --- /dev/null +++ b/tests/utils/test_output_gradient.py @@ -0,0 +1,203 @@ +from pathlib import Path + +import metatensor.torch +import pytest +import rascaline.torch +import torch +from metatensor.torch.atomistic import ModelCapabilities, ModelOutput + +from metatensor.models import soap_bpnn +from metatensor.models.utils.data import read_structures +from metatensor.models.utils.output_gradient import compute_gradient + + +RESOURCES_PATH = Path(__file__).parent.resolve() / ".." 
/ "resources" + + +@pytest.mark.parametrize("is_training", [True, False]) +def test_forces(is_training): + """Test that the forces are calculated correctly""" + + capabilities = ModelCapabilities( + length_unit="Angstrom", + species=[1, 6, 7, 8], + outputs={ + "energy": ModelOutput( + quantity="energy", + unit="eV", + ) + }, + ) + + model = soap_bpnn.Model(capabilities) + structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz")[:5] + structures = rascaline.torch.systems_to_torch( + structures, positions_requires_grad=True + ) + output = model(structures) + position_gradients = compute_gradient( + output["energy"].block().values, + [structure.positions for structure in structures], + is_training=is_training, + ) + forces = [-position_gradient for position_gradient in position_gradients] + + jitted_model = torch.jit.script(model) + structures = rascaline.torch.systems_to_torch( + structures, positions_requires_grad=True + ) + output = jitted_model(structures) + jitted_position_gradients = compute_gradient( + output["energy"].block().values, + [structure.positions for structure in structures], + is_training=is_training, + ) + jitted_forces = [ + -position_gradient for position_gradient in jitted_position_gradients + ] + + for f, jf in zip(forces, jitted_forces): + assert torch.allclose(f, jf) + + +@pytest.mark.parametrize("is_training", [True, False]) +def test_virial(is_training): + """Test that the virial is calculated correctly""" + + capabilities = ModelCapabilities( + length_unit="Angstrom", + species=[21, 23, 24, 27, 29, 39, 40, 41, 72, 74, 78], + outputs={ + "energy": ModelOutput( + quantity="energy", + unit="eV", + ) + }, + ) + + model = soap_bpnn.Model(capabilities) + structures = read_structures(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:2] + + displacements = [ + torch.eye( + 3, requires_grad=True, dtype=system.cell.dtype, device=system.cell.device + ) + for system in structures + ] + systems = [ + metatensor.torch.atomistic.System( + positions=system.positions @ displacement, + cell=system.cell @ displacement, + species=system.species, + ) + for system, displacement in zip(structures, displacements) + ] + + output = model(systems) + displacement_gradients = compute_gradient( + output["energy"].block().values, + displacements, + is_training=is_training, + ) + virial = [-cell_gradient for cell_gradient in displacement_gradients] + + jitted_model = torch.jit.script(model) + + displacements = [ + torch.eye( + 3, requires_grad=True, dtype=system.cell.dtype, device=system.cell.device + ) + for system in structures + ] + systems = [ + metatensor.torch.atomistic.System( + positions=system.positions @ displacement, + cell=system.cell @ displacement, + species=system.species, + ) + for system, displacement in zip(structures, displacements) + ] + + output = jitted_model(systems) + jitted_displacement_gradients = compute_gradient( + output["energy"].block().values, + displacements, + is_training=is_training, + ) + jitted_virial = [-cell_gradient for cell_gradient in jitted_displacement_gradients] + + for v, jv in zip(virial, jitted_virial): + assert torch.allclose(v, jv) + + +@pytest.mark.parametrize("is_training", [True, False]) +def test_both(is_training): + """Test that the forces and virial are calculated correctly together""" + + capabilities = ModelCapabilities( + length_unit="Angstrom", + species=[21, 23, 24, 27, 29, 39, 40, 41, 72, 74, 78], + outputs={ + "energy": ModelOutput( + quantity="energy", + unit="eV", + ) + }, + ) + + model = soap_bpnn.Model(capabilities) + 
structures = read_structures(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:2]
+
+    # Here we re-create displacements and systems, otherwise torch
+    # complains that the graph has already been freed in the last grad call
+    displacements = [
+        torch.eye(
+            3, requires_grad=True, dtype=system.cell.dtype, device=system.cell.device
+        )
+        for system in structures
+    ]
+    systems = [
+        metatensor.torch.atomistic.System(
+            positions=system.positions @ displacement,
+            cell=system.cell @ displacement,
+            species=system.species,
+        )
+        for system, displacement in zip(structures, displacements)
+    ]
+
+    output = model(systems)
+    gradients = compute_gradient(
+        output["energy"].block().values,
+        [system.positions for system in systems] + displacements,
+        is_training=is_training,
+    )
+    f_and_v = [-gradient for gradient in gradients]
+
+    displacements = [
+        torch.eye(
+            3, requires_grad=True, dtype=system.cell.dtype, device=system.cell.device
+        )
+        for system in structures
+    ]
+    systems = [
+        metatensor.torch.atomistic.System(
+            positions=system.positions @ displacement,
+            cell=system.cell @ displacement,
+            species=system.species,
+        )
+        for system, displacement in zip(structures, displacements)
+    ]
+
+    jitted_model = torch.jit.script(model)
+    output = jitted_model(systems)
+    jitted_gradients = compute_gradient(
+        output["energy"].block().values,
+        [system.positions for system in systems] + displacements,
+        is_training=is_training,
+    )
+    jitted_f_and_v = [-jitted_gradient for jitted_gradient in jitted_gradients]
+
+    for fv, jfv in zip(f_and_v, jitted_f_and_v):
+        assert torch.allclose(fv, jfv)
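
For reference, the pieces exercised by the tests above compose into a single force-evaluation pattern. What follows is a minimal sketch, not part of the patch itself, assuming only the signatures the tests already use: soap_bpnn.Model(capabilities), the two-argument forward model(systems, requested_outputs), compute_gradient(values, inputs, is_training), and a structure file such as "qm9_reduced_100.xyz" being available on disk.

    # Minimal usage sketch; all calls mirror the tests above, the file path
    # "qm9_reduced_100.xyz" is a stand-in for any readable structure file.
    import rascaline.torch
    from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

    from metatensor.models import soap_bpnn
    from metatensor.models.utils.data import read_structures
    from metatensor.models.utils.output_gradient import compute_gradient

    capabilities = ModelCapabilities(
        length_unit="Angstrom",
        species=[1, 6, 7, 8],
        outputs={"energy": ModelOutput(quantity="energy", unit="eV")},
    )
    model = soap_bpnn.Model(capabilities)

    structures = read_structures("qm9_reduced_100.xyz")[:5]
    systems = rascaline.torch.systems_to_torch(
        structures, positions_requires_grad=True
    )

    # restrict the forward pass to the "energy" output, then differentiate
    # it with respect to the positions to obtain forces
    output = model(systems, ["energy"])
    position_gradients = compute_gradient(
        output["energy"].block().values,
        [system.positions for system in systems],
        is_training=False,
    )
    forces = [-gradient for gradient in position_gradients]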