From b8cd9ba167b2590e0ad98445f7896300deae7d09 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 26 Oct 2024 13:56:24 +0200 Subject: [PATCH 01/12] Add sample `TensorMap` to `TargetInfo` --- src/metatrain/cli/eval.py | 71 ++- src/metatrain/cli/train.py | 18 +- .../alchemical_model/tests/test_exported.py | 7 +- .../tests/test_functionality.py | 7 +- .../alchemical_model/tests/test_invariance.py | 7 +- .../alchemical_model/tests/test_regression.py | 6 +- .../test_torch_alchemical_compatibility.py | 7 +- .../tests/test_torchscript.py | 11 +- .../experimental/alchemical_model/trainer.py | 15 +- .../experimental/gap/tests/test_errors.py | 10 +- .../experimental/gap/tests/test_regression.py | 21 +- .../gap/tests/test_torchscript.py | 15 +- .../experimental/pet/tests/test_exported.py | 7 +- .../pet/tests/test_functionality.py | 15 +- .../pet/tests/test_pet_compatibility.py | 7 +- .../pet/tests/test_torchscript.py | 11 +- src/metatrain/experimental/soap_bpnn/model.py | 6 +- .../soap_bpnn/tests/test_continue.py | 9 +- .../soap_bpnn/tests/test_exported.py | 7 +- .../soap_bpnn/tests/test_functionality.py | 19 +- .../soap_bpnn/tests/test_invariance.py | 7 +- .../soap_bpnn/tests/test_regression.py | 7 +- .../soap_bpnn/tests/test_torchscript.py | 15 +- .../experimental/soap_bpnn/trainer.py | 15 +- src/metatrain/utils/additive/remove.py | 8 +- src/metatrain/utils/data/__init__.py | 2 - src/metatrain/utils/data/dataset.py | 252 ++++++----- src/metatrain/utils/data/extract_targets.py | 48 -- src/metatrain/utils/data/get_dataset.py | 8 +- src/metatrain/utils/data/readers/readers.py | 46 +- src/metatrain/utils/evaluate_model.py | 6 +- src/metatrain/utils/llpr.py | 9 +- src/metatrain/utils/metrics.py | 4 +- src/metatrain/utils/testing.py | 73 ++++ tests/cli/test_eval_model.py | 5 +- tests/cli/test_export_model.py | 9 +- tests/cli/test_train_model.py | 1 + tests/utils/data/test_dataset.py | 409 ++++++++---------- tests/utils/data/test_readers.py | 4 +- tests/utils/test_additive.py 
| 73 ++-- tests/utils/test_evaluate_model.py | 3 +- tests/utils/test_export.py | 15 +- tests/utils/test_external_naming.py | 7 +- tests/utils/test_output_gradient.py | 12 +- 44 files changed, 738 insertions(+), 576 deletions(-) delete mode 100644 src/metatrain/utils/data/extract_targets.py create mode 100644 src/metatrain/utils/testing.py diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 1a9de5a42..25bcbdb2c 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -15,7 +15,6 @@ from ..utils.data import ( Dataset, TargetInfo, - TargetInfoDict, collate_fn, read_systems, read_targets, @@ -159,7 +158,7 @@ def _concatenate_tensormaps( def _eval_targets( model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule], dataset: Union[Dataset, torch.utils.data.Subset], - options: TargetInfoDict, + options: Dict[str, TargetInfo], return_predictions: bool, check_consistency: bool = False, ) -> Optional[Dict[str, TensorMap]]: @@ -335,17 +334,17 @@ def eval_model( # (but we don't/can't calculate RMSEs) # TODO: allow the user to specify which outputs to evaluate eval_targets = {} - eval_info_dict = TargetInfoDict() - gradients = ["positions"] - if all(not torch.all(system.cell == 0) for system in eval_systems): - # only add strain if all structures have cells - gradients.append("strain") + eval_info_dict = {} + do_strain_grad = all( + not torch.all(system.cell == 0) for system in eval_systems + ) + layout = _get_energy_layout(do_strain_grad) # TODO: layout from the user for key in model.capabilities().outputs.keys(): eval_info_dict[key] = TargetInfo( quantity=model.capabilities().outputs[key].quantity, unit=model.capabilities().outputs[key].unit, - per_atom=False, # TODO: allow the user to specify this - gradients=gradients, + # TODO: allow the user to specify whether per-atom or not + layout=layout, ) eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets}) @@ -368,3 +367,57 @@ def eval_model( 
capabilities=model.capabilities(), predictions=predictions, ) + + +def _get_energy_layout(strain_gradient: bool) -> TensorMap: + block = TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), + ) + position_gradient_block = TensorBlock( + values=torch.empty(0, 3, 1), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), + ) + block.add_gradient("positions", position_gradient_block) + + if strain_gradient: + strain_gradient_block = TensorBlock( + values=torch.empty(0, 3, 3, 1), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), + ) + block.add_gradient("strain", strain_gradient_block) + + energy_layout = TensorMap( + keys=Labels.single(), + blocks=[block], + ) + return energy_layout diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 3325560ae..d941fa8aa 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -21,7 +21,7 @@ ) from ..utils.data import ( DatasetInfo, - TargetInfoDict, + TargetInfo, get_atomic_types, get_dataset, get_stats, @@ -227,11 +227,17 @@ def train_model( options["training_set"] = expand_dataset_config(options["training_set"]) train_datasets = [] - target_infos = TargetInfoDict() - for train_options in options["training_set"]: - dataset, target_info_dict = get_dataset(train_options) + target_info_dict: Dict[str, TargetInfo] = {} + for train_options in options["training_set"]: # loop 
over training sets + dataset, target_info_dict_single = get_dataset(train_options) train_datasets.append(dataset) - target_infos.update(target_info_dict) + intersecting_keys = target_info_dict.keys() & target_info_dict_single.keys() + for key in intersecting_keys: + if target_info_dict[key] != target_info_dict_single[key]: + raise ValueError( + f"Target information for key {key} differs between training sets." + ) + target_info_dict.update(target_info_dict_single) train_size = 1.0 @@ -320,7 +326,7 @@ def train_model( dataset_info = DatasetInfo( length_unit=options["training_set"][0]["systems"]["length_unit"], atomic_types=atomic_types, - targets=target_infos, + targets=target_info_dict, ) ########################### diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 891983693..d4a1e6147 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -3,11 +3,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS @@ -22,7 +23,9 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype) exported = model.export() diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index 7ee3331af..18acedc5e 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -2,11 +2,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS @@ -18,7 +19,9 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index 9d9a84dd9..c8a42676d 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -5,11 +5,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH, MODEL_HYPERS @@ -20,7 +21,9 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index 648c91c18..0d82379c3 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -13,11 +13,11 @@ read_systems, read_targets, ) -from metatrain.utils.data.dataset import TargetInfoDict from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -31,8 +31,8 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py index 03b7ef1df..fdcb9b8bf 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py @@ -14,8 +14,9 @@ from metatrain.experimental.alchemical_model.utils import ( systems_to_torch_alchemical_batch, ) -from 
metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems +from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS, QM9_DATASET_PATH @@ -70,7 +71,9 @@ def test_alchemical_model_inference(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=unique_numbers, - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py index 33e0b9e9f..30d5511d7 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py @@ -1,7 +1,8 @@ import torch from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS @@ -12,7 +13,9 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) @@ -25,7 +28,9 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) torch.jit.save( diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 3ed190c00..4127e7036 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -9,7 +9,6 @@ from ...utils.data import ( CombinedDataLoader, Dataset, - TargetInfoDict, check_datasets, collate_fn, get_all_targets, @@ -247,12 +246,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{ - key: model.dataset_info.targets[key] - for key in targets.keys() - } - ), + {key: model.dataset_info.targets[key] for key in targets.keys()}, is_training=True, ) @@ -295,12 +289,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{ - key: model.dataset_info.targets[key] - for key in targets.keys() - } - ), + {key: model.dataset_info.targets[key] for key in targets.keys()}, is_training=False, ) diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index b0359bd80..791e061c6 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -12,10 +12,10 @@ Dataset, DatasetInfo, 
TargetInfo, - TargetInfoDict, read_systems, read_targets, ) +from metatrain.utils.testing import energy_force_layout from . import DATASET_ETHANOL_PATH, DEFAULT_HYPERS @@ -61,9 +61,11 @@ def test_ethanol_regression_train_and_invariance(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 30 - target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) - ) + target_info_dict = { + "energy": TargetInfo( + quantity="energy", unit="kcal/mol", layout=energy_force_layout + ) + } dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index 81212353c..a3d703e76 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -8,8 +8,9 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_force_layout, energy_layout from . 
import DATASET_ETHANOL_PATH, DATASET_PATH, DEFAULT_HYPERS @@ -25,8 +26,8 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets @@ -57,8 +58,10 @@ def test_regression_train_and_invariance(): targets, _ = read_targets(OmegaConf.create(conf)) dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) - target_info_dict = TargetInfoDict() - target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + target_info_dict = {} + target_info_dict["mtt::U0"] = TargetInfo( + quantity="energy", unit="eV", layout=energy_layout + ) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -138,9 +141,11 @@ def test_ethanol_regression_train_and_invariance(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 900 - target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) - ) + target_info_dict = { + "energy": TargetInfo( + quantity="energy", unit="kcal/mol", layout=energy_force_layout + ) + } dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index 967a83353..e7d1cbc2f 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ -2,8 +2,9 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data 
import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, DEFAULT_HYPERS @@ -13,8 +14,10 @@ def test_torchscript(): """Tests that the model can be jitted.""" - target_info_dict = TargetInfoDict() - target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + target_info_dict = {} + target_info_dict["mtt::U0"] = TargetInfo( + quantity="energy", unit="eV", layout=energy_layout + ) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -34,8 +37,6 @@ def test_torchscript(): targets, _ = read_targets(OmegaConf.create(conf)) systems = read_systems(DATASET_PATH) - # for system in systems: - # system.types = torch.ones(len(system.types), dtype=torch.int32) dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() @@ -64,8 +65,8 @@ def test_torchscript(): def test_torchscript_save(): """Tests that the model can be jitted and saved.""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index f67a15e4c..3267c2992 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -11,12 +11,13 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.export import export from metatrain.utils.neighbor_lists 
import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -32,7 +33,9 @@ def test_to(device): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index ddf527603..4ca004a91 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -18,12 +18,13 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.jsonschema import validate from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -65,7 +66,9 @@ def test_prediction(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -115,7 +118,9 @@ def test_per_atom_predictions_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", 
unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -166,7 +171,9 @@ def test_selected_atoms_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py index 04ecae1bc..205e4e25b 100644 --- a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py @@ -17,8 +17,9 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.experimental.pet.utils import systems_to_batch_dict from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH @@ -97,7 +98,9 @@ def test_predictions_compatibility(cutoff): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=structure.numbers, - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) capabilities = ModelCapabilities( length_unit="Angstrom", diff --git a/src/metatrain/experimental/pet/tests/test_torchscript.py b/src/metatrain/experimental/pet/tests/test_torchscript.py index df0584cd3..15adc95a8 100644 --- a/src/metatrain/experimental/pet/tests/test_torchscript.py +++ b/src/metatrain/experimental/pet/tests/test_torchscript.py @@ -4,7 +4,8 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -16,7 +17,9 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -31,7 +34,9 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 556f3ef52..2684e34db 100644 --- 
a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -204,7 +204,11 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": new_atomic_types = [ at for at in merged_info.atomic_types if at not in self.atomic_types ] - new_targets = merged_info.targets - self.dataset_info.targets + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } if len(new_atomic_types) > 0: raise ValueError( diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index fc732897e..50fa7118a 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -6,8 +6,9 @@ from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -22,8 +23,10 @@ def test_continue(monkeypatch, tmp_path): systems = read_systems(DATASET_PATH) systems = [system.to(torch.float32) for system in systems] - target_info_dict = TargetInfoDict() - target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + target_info_dict = {} + target_info_dict["mtt::U0"] = TargetInfo( + quantity="energy", unit="eV", layout=energy_layout + ) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index 63242161e..f23e7b3c8 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -3,11 +3,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS @@ -22,7 +23,9 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info).to(dtype=dtype) exported = model.export() diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 25dd250b6..de5f5e3bb 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -7,7 +7,8 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.architectures import check_architecture_options -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . 
import DEFAULT_HYPERS, MODEL_HYPERS @@ -19,7 +20,9 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -42,7 +45,9 @@ def test_prediction_subset_atoms(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -108,7 +113,9 @@ def test_output_last_layer_features(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -179,7 +186,9 @@ def test_output_per_atom(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py index 2b5835b74..92c96767d 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py @@ -5,7 +5,8 @@ from metatensor.torch.atomistic import systems_to_torch from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from 
metatrain.utils.testing import energy_layout from . import DATASET_PATH, MODEL_HYPERS @@ -16,7 +17,9 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 7b4161ddb..35f691051 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -6,8 +6,9 @@ from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -21,8 +22,8 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 4d6b89898..d35f8deee 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -4,7 +4,8 @@ from metatensor.torch.atomistic import System from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS @@ -15,7 +16,9 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) model = torch.jit.script(model) @@ -39,7 +42,9 @@ def test_torchscript_with_identity(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) hypers = copy.deepcopy(MODEL_HYPERS) hypers["bpnn"]["layernorm"] = False @@ -65,7 +70,9 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) torch.jit.save( diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index aed858bcd..810d9e471 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -8,8 +8,7 @@ from torch.utils.data import DataLoader, DistributedSampler from ...utils.additive import remove_additive -from ...utils.data import CombinedDataLoader, Dataset, TargetInfoDict, collate_fn -from ...utils.data.extract_targets import get_targets_dict +from ...utils.data import CombinedDataLoader, Dataset, collate_fn from ...utils.distributed.distributed_data_parallel import DistributedDataParallel from ...utils.distributed.slurm import DistributedEnvironment from ...utils.evaluate_model import evaluate_model @@ -182,9 +181,7 @@ def train( val_dataloader = CombinedDataLoader(val_dataloaders, 
shuffle=False) # Extract all the possible outputs and their gradients: - train_targets = get_targets_dict( - train_datasets, (model.module if is_distributed else model).dataset_info - ) + train_targets = (model.module if is_distributed else model).dataset_info.targets outputs_list = [] for target_name, target_info in train_targets.items(): outputs_list.append(target_name) @@ -270,9 +267,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{key: train_targets[key] for key in targets.keys()} - ), + {key: train_targets[key] for key in targets.keys()}, is_training=True, ) @@ -325,9 +320,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{key: train_targets[key] for key in targets.keys()} - ), + {key: train_targets[key] for key in targets.keys()}, is_training=False, ) diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py index 4235af1fc..899a96a07 100644 --- a/src/metatrain/utils/additive/remove.py +++ b/src/metatrain/utils/additive/remove.py @@ -1,12 +1,12 @@ import warnings -from typing import Dict, List, Union +from typing import Dict, List import metatensor.torch import torch from metatensor.torch import TensorMap from metatensor.torch.atomistic import System -from ..data import TargetInfo, TargetInfoDict +from ..data import TargetInfo from ..evaluate_model import evaluate_model @@ -14,7 +14,7 @@ def remove_additive( systems: List[System], targets: Dict[str, TensorMap], additive_model: torch.nn.Module, - target_info_dict: Union[Dict[str, TargetInfo], TargetInfoDict], + target_info_dict: Dict[str, TargetInfo], ): """Remove an additive contribution from the training targets. @@ -35,7 +35,7 @@ def remove_additive( additive_contribution = evaluate_model( additive_model, systems, - TargetInfoDict(**{key: target_info_dict[key] for key in targets.keys()}), + {key: target_info_dict[key] for key in targets.keys()}, is_training=False, # we don't need any gradients w.r.t. 
any parameters ) diff --git a/src/metatrain/utils/data/__init__.py b/src/metatrain/utils/data/__init__.py index 0aa128433..e93621947 100644 --- a/src/metatrain/utils/data/__init__.py +++ b/src/metatrain/utils/data/__init__.py @@ -1,7 +1,6 @@ from .dataset import ( # noqa: F401 Dataset, TargetInfo, - TargetInfoDict, DatasetInfo, get_atomic_types, get_all_targets, @@ -21,5 +20,4 @@ from .writers import write_predictions # noqa: F401 from .combine_dataloaders import CombinedDataLoader # noqa: F401 from .system_to_ase import system_to_ase # noqa: F401 -from .extract_targets import get_targets_dict # noqa: F401 from .get_dataset import get_dataset # noqa: F401 diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index cd6a68ab2..f9761b129 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -1,8 +1,8 @@ import math import warnings -from collections import UserDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union +import metatensor.torch import numpy as np from metatensor.learn.data import Dataset, group_and_join from metatensor.torch import TensorMap @@ -15,40 +15,49 @@ class TargetInfo: """A class that contains information about a target. - :param quantity: The quantity of the target. + :param quantity: The physical quantity of the target (e.g., "energy"). + :param layout: The layout of the target, as a ``TensorMap`` with 0 samples. + This ``TensorMap`` will contain important information such as the names of + the ``samples``, as well as the ``components`` and ``properties`` of the + target and their gradients. :param unit: The unit of the target. If :py:obj:`None` the ``unit`` will be set to an empty string ``""``. - :param per_atom: Whether the target is a per-atom quantity. - :param gradients: List containing the gradient names of the target that are present - in the target. Examples are ``"positions"`` or ``"strain"``. 
``gradients`` will - be stored as a sorted list of **unique** gradients. """ def __init__( self, quantity: str, + layout: TensorMap, unit: Union[None, str] = "", - per_atom: bool = False, - gradients: Optional[List[str]] = None, ): + # one of these will be set to True inside the _check_layout method + self.is_scalar = False + self.is_cartesian = False + self.is_spherical = False + + self._check_layout(layout) + self.quantity = quantity + self.layout = layout self.unit = unit if unit is not None else "" - self.per_atom = per_atom - self._gradients = set(gradients) if gradients is not None else set() @property def gradients(self) -> List[str]: """Sorted and unique list of gradient names.""" - return sorted(self._gradients) + if self.is_scalar: + return sorted(self.layout.block().gradients_list()) + else: + return [] - @gradients.setter - def gradients(self, value: List[str]): - self._gradients = set(value) + @property + def per_atom(self) -> bool: + """Whether the target is per atom.""" + return "atom" in self.layout.block(0).samples.names def __repr__(self): return ( f"TargetInfo(quantity={self.quantity!r}, unit={self.unit!r}, " - f"per_atom={self.per_atom!r}, gradients={self.gradients!r})" + f"layout={self.layout!r})" ) def __eq__(self, other): @@ -60,109 +69,126 @@ def __eq__(self, other): return ( self.quantity == other.quantity and self.unit == other.unit - and self.per_atom == other.per_atom - and self._gradients == other._gradients + and metatensor.torch.equal(self.layout, other.layout) ) - def copy(self) -> "TargetInfo": - """Return a shallow copy of the TargetInfo.""" - return TargetInfo( - quantity=self.quantity, - unit=self.unit, - per_atom=self.per_atom, - gradients=self.gradients.copy(), - ) - - def update(self, other: "TargetInfo") -> None: - """Update this instance with the union of itself and ``other``. 
+ def _check_layout(self, layout: TensorMap) -> None: + """Check that the layout is a valid layout.""" + + # examine basic properties of all blocks + for block in layout.blocks(): + for sample_name in block.samples.names: + if sample_name not in ["system", "atom"]: + raise ValueError( + "The layout ``TensorMap`` of a target should only have samples " + "named 'system' or 'atom', but found " + f"'{sample_name}' instead." + ) + if len(block.values) != 0: + raise ValueError( + "The layout ``TensorMap`` of a target should have 0 " + f"samples, but found {len(block.values)} samples." + ) - :raises ValueError: If ``quantity``, ``unit`` or ``per_atom`` do not match. - """ - if self.quantity != other.quantity: - raise ValueError( - f"Can't update TargetInfo with a different `quantity`: " - f"({self.quantity} != {other.quantity})" - ) + # examine the components of the first block to decide whether this is + # a scalar, a Cartesian tensor or a spherical tensor - if self.unit != other.unit: + if len(layout) == 0: raise ValueError( - f"Can't update TargetInfo with a different `unit`: " - f"({self.unit} != {other.unit})" + "The layout ``TensorMap`` of a target should have at least one " + "block, but found 0 blocks." ) - - if self.per_atom != other.per_atom: - raise ValueError( - f"Can't update TargetInfo with a different `per_atom` property: " - f"({self.per_atom} != {other.per_atom})" - ) - - self.gradients = self.gradients + other.gradients - - def union(self, other: "TargetInfo") -> "TargetInfo": - """Return the union of this instance with ``other``.""" - new = self.copy() - new.update(other) - return new - - -class TargetInfoDict(UserDict): - """ - A custom dictionary class for storing and managing ``TargetInfo`` instances. - - The subclass handles the update of :py:class:`TargetInfo` if a ``key`` is already - present. - """ - - # We use a `UserDict` with special methods because a normal dict does not support - # the update of nested instances. 
- def __setitem__(self, key, value): - if not isinstance(value, TargetInfo): - raise ValueError("value to set is not a `TargetInfo` instance") - if key in self: - self[key].update(value) - else: - super().__setitem__(key, value) - - def __and__(self, other: "TargetInfoDict") -> "TargetInfoDict": - return self.intersection(other) - - def __sub__(self, other: "TargetInfoDict") -> "TargetInfoDict": - return self.difference(other) - - def union(self, other: "TargetInfoDict") -> "TargetInfoDict": - """Union of this instance with ``other``.""" - new = self.copy() - new.update(other) - return new - - def intersection(self, other: "TargetInfoDict") -> "TargetInfoDict": - """Intersection of the the two instances as a new ``TargetInfoDict``. - - (i.e. all elements that are in both sets.) - - :raises ValueError: If intersected items with the same key are not the same. - """ - new_keys = self.keys() & other.keys() - - self_intersect = TargetInfoDict(**{key: self[key] for key in new_keys}) - other_intersect = TargetInfoDict(**{key: other[key] for key in new_keys}) - - if self_intersect == other_intersect: - return self_intersect + components_first_block = layout.block(0).components + if len(components_first_block) == 0: + self.is_scalar = True + elif components_first_block[0].names[0].startswith("xyz"): + self.is_cartesian = True + elif ( + len(components_first_block) == 1 + and components_first_block[0].names[0] == "o3_mu" + ): + self.is_spherical = True else: raise ValueError( - "Intersected items with the same key are not the same. Intersected " - f"keys are {','.join(new_keys)}" + "The layout ``TensorMap`` of a target should be " + "either scalars, Cartesian tensors or spherical tensors. The type of " + "the target could not be determined." ) - def difference(self, other: "TargetInfoDict") -> "TargetInfoDict": - """Difference of two instances as a new ``TargetInfoDict``. - - (i.e. all elements that are in this set but not in the other.) 
- """ + if self.is_scalar: + if layout.keys.names != ["_"]: + raise ValueError( + "The layout ``TensorMap`` of a scalar target should have " + "a single key sample named '_'." + ) + if len(layout.blocks()) != 1: + raise ValueError( + "The layout ``TensorMap`` of a scalar target should have " + "a single block." + ) + gradients_names = layout.block(0).gradients_list() + for gradient_name in gradients_names: + if gradient_name not in ["positions", "strain"]: + raise ValueError( + "Only `positions` and `strain` gradients are supported for " + "scalar targets. " + f"Found '{gradient_name}' instead." + ) + if self.is_cartesian: + if layout.keys.names != ["_"]: + raise ValueError( + "The layout ``TensorMap`` of a Cartesian tensor target should have " + "a single key sample named '_'." + ) + if len(layout.blocks()) != 1: + raise ValueError( + "The layout ``TensorMap`` of a Cartesian tensor target should have " + "a single block." + ) + if len(layout.block(0).gradients_list()) > 0: + raise ValueError( + "Gradients of Cartesian tensor targets are not supported." + ) - new_keys = self.keys() - other.keys() - return TargetInfoDict(**{key: self[key] for key in new_keys}) + if self.is_spherical: + if layout.keys.names != ["o3_lambda", "o3_sigma"]: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target " + "should have two keys named 'o3_lambda' and 'o3_sigma'." + f"Found '{layout.keys.names}' instead." + ) + for key, block in layout.items(): + o3_lambda, o3_sigma = int(key.values[0].item()), int( + key.values[1].item() + ) + if o3_sigma not in [-1, 1]: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a key sample 'o3_sigma' that is either -1 or 1." + f"Found '{o3_sigma}' instead." + ) + if o3_lambda < 0: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a key sample 'o3_lambda' that is non-negative." + f"Found '{o3_lambda}' instead." 
+ ) + components = block.components + if len(components) != 1: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a single component." + ) + if len(components[0]) != 2 * o3_lambda + 1: + raise ValueError( + "Each ``TensorBlock`` of a spherical tensor target should have " + "a component with 2*o3_lambda + 1 elements." + f"Found '{len(components[0])}' elements instead." + ) + if len(block.gradients_list()) > 0: + raise ValueError( + "Gradients of spherical tensor targets are not supported." + ) class DatasetInfo: @@ -180,7 +206,7 @@ class DatasetInfo: """ def __init__( - self, length_unit: str, atomic_types: List[int], targets: TargetInfoDict + self, length_unit: str, atomic_types: List[int], targets: Dict[str, TargetInfo] ): self.length_unit = length_unit if length_unit is not None else "" self._atomic_types = set(atomic_types) @@ -233,6 +259,14 @@ def update(self, other: "DatasetInfo") -> None: ) self.atomic_types = self.atomic_types + other.atomic_types + + intersecting_target_keys = self.targets.keys() & other.targets.keys() + for key in intersecting_target_keys: + if self.targets[key] != other.targets[key]: + raise ValueError( + f"Can't update DatasetInfo with different target information for " + f"target '{key}': {self.targets[key]} != {other.targets[key]}" + ) self.targets.update(other.targets) def union(self, other: "DatasetInfo") -> "DatasetInfo": diff --git a/src/metatrain/utils/data/extract_targets.py b/src/metatrain/utils/data/extract_targets.py deleted file mode 100644 index ee86b29d4..000000000 --- a/src/metatrain/utils/data/extract_targets.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Dict, List, Union - -import torch - -from metatrain.utils.data import Dataset - -from .dataset import DatasetInfo, TargetInfo - - -def get_targets_dict( - datasets: List[Union[Dataset, torch.utils.data.Subset]], dataset_info: DatasetInfo -) -> Dict[str, TargetInfo]: - """ - This is a helper function that extracts all the 
possible targets and their - gradients from a list of datasets. - - :param datasets: A list of Datasets or Subsets. - :param dataset_info: A DatasetInfo object containing further - information about the dataset, namely the unit and quantity of the - targets. - - :returns: A dictionary mapping target names to ``TargetInfo`` objects. - - :raises ValueError: If the ``DatasetInfo`` object does not contain any of - the expected targets. - """ - - targets_dict = {} - for dataset in datasets: - targets = next(iter(dataset)) - targets = targets._asdict() - targets.pop("system") # system not needed - - # targets is now a dictionary of TensorMaps - for target_name, target_tmap in targets.items(): - if target_name not in dataset_info.targets.keys(): - raise ValueError( - f"Target {target_name} not found in the targets " - "specified in dataset_info." - ) - if target_name not in targets_dict: - targets_dict[target_name] = TargetInfo( - quantity=dataset_info.targets[target_name].quantity, - unit=dataset_info.targets[target_name].unit, - gradients=target_tmap.block(0).gradients_list(), - ) - - return targets_dict diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 2f95263c5..502aea40f 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -1,12 +1,12 @@ -from typing import Tuple +from typing import Dict, Tuple from omegaconf import DictConfig -from .dataset import Dataset, TargetInfoDict +from .dataset import Dataset, TargetInfo from .readers import read_systems, read_targets -def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]: +def get_dataset(options: DictConfig) -> Tuple[Dataset, Dict[str, TargetInfo]]: """ Gets a dataset given a configuration dictionary. @@ -18,7 +18,7 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]: systems and targets in the dataset. 
:returns: A tuple containing a ``Dataset`` object and a - ``TargetInfoDict`` containing additional information (units, + ``Dict[str, TargetInfo]`` containing additional information (units, physical quantities, ...) on the targets in the dataset """ diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index b095195b1..dc81af7c3 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -8,7 +8,7 @@ from metatensor.torch.atomistic import System from omegaconf import DictConfig -from ..dataset import TargetInfo, TargetInfoDict +from ..dataset import TargetInfo logger = logging.getLogger(__name__) @@ -161,7 +161,7 @@ def read_virial( def read_targets( conf: DictConfig, -) -> Tuple[Dict[str, List[TensorMap]], TargetInfoDict]: +) -> Tuple[Dict[str, List[TensorMap]], Dict[str, TargetInfo]]: """Reading all target information from a fully expanded config. To get such a config you can use :func:`expand_dataset_config @@ -175,9 +175,8 @@ def read_targets( :param conf: config containing the keys for what should be read. :returns: Dictionary containing a list of TensorMaps for each target section in the - config as well as a :py:class:`TargetInfoDict - ` instance containing the metadata of the - targets. + config as well as a ``Dict[str, TargetInfo]`` object + containing the metadata of the targets. :raises ValueError: if the target name is not valid. 
Valid target names are those that either start with ``mtt::`` or those that are in the list of @@ -185,7 +184,7 @@ def read_targets( https://docs.metatensor.org/latest/atomistic/outputs.html) """ target_dictionary = {} - target_info_dictionary = TargetInfoDict() + target_info_dictionary = {} standard_outputs_list = ["energy"] for target_key, target in conf.items(): @@ -288,8 +287,39 @@ def read_targets( target_info_dictionary[target_key] = TargetInfo( quantity=target["quantity"], unit=target["unit"], - per_atom=False, # TODO: read this from the config - gradients=target_info_gradients, + layout=_empty_tensor_map_like(target_dictionary[target_key][0]), ) return target_dictionary, target_info_dictionary + + +def _empty_tensor_map_like(tensor_map: TensorMap) -> TensorMap: + new_keys = tensor_map.keys + new_blocks: List[TensorBlock] = [] + for block in tensor_map.blocks(): + new_block = _empty_tensor_block_like(block) + new_blocks.append(new_block) + return TensorMap(keys=new_keys, blocks=new_blocks) + + +def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock: + new_block = TensorBlock( + values=torch.empty( + (0,) + tensor_block.values.shape[1:], + dtype=tensor_block.values.dtype, + device=tensor_block.values.device, + ), + samples=Labels( + names=tensor_block.samples.names, + values=torch.empty( + (0, tensor_block.samples.values.shape[1]), + dtype=tensor_block.samples.values.dtype, + device=tensor_block.samples.values.device, + ), + ), + components=tensor_block.components, + properties=tensor_block.properties, + ) + for gradient_name, gradient in tensor_block.gradients(): + new_block.add_gradient(gradient_name, _empty_tensor_block_like(gradient)) + return new_block diff --git a/src/metatrain/utils/evaluate_model.py b/src/metatrain/utils/evaluate_model.py index 48447b879..71ad28686 100644 --- a/src/metatrain/utils/evaluate_model.py +++ b/src/metatrain/utils/evaluate_model.py @@ -12,7 +12,7 @@ register_autograd_neighbors, ) -from .data import 
TargetInfoDict +from .data import TargetInfo from .export import is_exported from .output_gradient import compute_gradient @@ -33,7 +33,7 @@ def evaluate_model( torch.jit._script.RecursiveScriptModule, ], systems: List[System], - targets: TargetInfoDict, + targets: Dict[str, TargetInfo], is_training: bool, check_consistency: bool = False, ) -> Dict[str, TensorMap]: @@ -234,7 +234,7 @@ def _get_model_outputs( torch.jit._script.RecursiveScriptModule, ], systems: List[System], - targets: TargetInfoDict, + targets: Dict[str, TargetInfo], check_consistency: bool, ) -> Dict[str, TensorMap]: if is_exported(model): diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index 164dd4267..6cfa42394 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -12,8 +12,7 @@ ) from torch.utils.data import DataLoader -from .data import DatasetInfo, TargetInfoDict, get_atomic_types -from .data.extract_targets import get_targets_dict +from .data import DatasetInfo, TargetInfo, get_atomic_types from .evaluate_model import evaluate_model from .per_atom import average_by_num_atoms @@ -261,7 +260,7 @@ class in ``metatrain``. def compute_covariance_as_pseudo_hessian( self, train_loader: DataLoader, - target_infos: TargetInfoDict, + target_infos: Dict[str, TargetInfo], loss_fn: Callable, parameters: List[torch.nn.Parameter], ) -> None: @@ -305,7 +304,7 @@ class in ``metatrain``. atomic_types=get_atomic_types(dataset), targets=target_infos, ) - train_targets = get_targets_dict([train_loader.dataset], dataset_info) + train_targets = dataset_info.targets device = self.covariance.device dtype = self.covariance.dtype for batch in train_loader: @@ -318,7 +317,7 @@ class in ``metatrain``. 
predictions = evaluate_model( self.model, systems, - TargetInfoDict(**{key: train_targets[key] for key in targets.keys()}), + {key: train_targets[key] for key in targets.keys()}, is_training=True, # keep the computational graph ) diff --git a/src/metatrain/utils/metrics.py b/src/metatrain/utils/metrics.py index 357b567d0..29ace5104 100644 --- a/src/metatrain/utils/metrics.py +++ b/src/metatrain/utils/metrics.py @@ -8,7 +8,7 @@ class RMSEAccumulator: """Accumulates the RMSE between predictions and targets for an arbitrary number of keys, each corresponding to one target.""" - def __init__(self): + def __init__(self) -> None: """Initialize the accumulator.""" self.information: Dict[str, Tuple[float, int]] = {} @@ -91,7 +91,7 @@ class MAEAccumulator: """Accumulates the MAE between predictions and targets for an arbitrary number of keys, each corresponding to one target.""" - def __init__(self): + def __init__(self) -> None: """Initialize the accumulator.""" self.information: Dict[str, Tuple[float, int]] = {} diff --git a/src/metatrain/utils/testing.py b/src/metatrain/utils/testing.py new file mode 100644 index 000000000..62dc7382c --- /dev/null +++ b/src/metatrain/utils/testing.py @@ -0,0 +1,73 @@ +# This file contains some example TensorMap layouts that can be +# used for testing purposes. 
+ +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +block = TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), +) +energy_layout = TensorMap( + keys=Labels.single(), + blocks=[block], +) + +block_with_position_gradients = block.copy() +position_gradient_block = TensorBlock( + values=torch.empty(0, 3, 1), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), +) +block_with_position_gradients.add_gradient("positions", position_gradient_block) +energy_force_layout = TensorMap( + keys=Labels.single(), + blocks=[block_with_position_gradients], +) + +block_with_position_and_strain_gradients = block_with_position_gradients.copy() +strain_gradient_block = TensorBlock( + values=torch.empty(0, 3, 3, 1), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), +) +block_with_position_and_strain_gradients.add_gradient("strain", strain_gradient_block) +energy_force_stress_layout = TensorMap( + keys=Labels.single(), + blocks=[block_with_position_and_strain_gradients], +) + +block_with_strain_gradients = block.copy() +block_with_strain_gradients.add_gradient("strain", strain_gradient_block) +energy_stress_layout = TensorMap( + keys=Labels.single(), + blocks=[block_with_strain_gradients], +) diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py index 6510deee9..6da30fb5c 100644 --- a/tests/cli/test_eval_model.py +++ 
b/tests/cli/test_eval_model.py @@ -11,6 +11,7 @@ from metatrain.cli.eval import eval_model from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import EVAL_OPTIONS_PATH, MODEL_HYPERS, MODEL_PATH, RESOURCES_PATH @@ -84,9 +85,7 @@ def test_eval_export(monkeypatch, tmp_path, options): length_unit="angstrom", atomic_types={1, 6, 7, 8}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=[] - ) + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/cli/test_export_model.py b/tests/cli/test_export_model.py index a3e62e9d7..ca93a7b00 100644 --- a/tests/cli/test_export_model.py +++ b/tests/cli/test_export_model.py @@ -15,6 +15,7 @@ from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.architectures import find_all_architectures from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS, RESOURCES_PATH @@ -28,9 +29,7 @@ def test_export(monkeypatch, tmp_path, path): length_unit="angstrom", atomic_types={1}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=[] - ) + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -93,9 +92,7 @@ def test_reexport(monkeypatch, tmp_path): length_unit="angstrom", atomic_types={1, 6, 7, 8}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=[] - ) + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 5e33c5c7e..dee3da5b8 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -473,6 +473,7 @@ def test_continue_different_dataset(options, monkeypatch, tmp_path): options["training_set"]["systems"]["read_from"] = "ethanol_reduced_100.xyz" options["training_set"]["targets"]["energy"]["key"] = "energy" + options["training_set"]["targets"]["energy"]["forces"] = False train_model(options, continue_from=MODEL_PATH_64_BIT) diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index 20100aa3d..9c76c5a7b 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -2,13 +2,13 @@ import pytest import torch +from metatensor.torch import Labels, TensorBlock, TensorMap from omegaconf import OmegaConf from metatrain.utils.data import ( Dataset, DatasetInfo, TargetInfo, - TargetInfoDict, check_datasets, collate_fn, get_all_targets, @@ -22,81 +22,159 @@ RESOURCES_PATH = Path(__file__).parents[2] / "resources" -def test_target_info_default(): - target_info = TargetInfo(quantity="energy", unit="kcal/mol") +@pytest.fixture +def layout_scalar(): + return TensorMap( + 
keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), + ) + ], + ) + + +@pytest.fixture +def layout_spherical(): + return TensorMap( + keys=Labels( + names=["o3_lambda", "o3_sigma"], + values=torch.tensor([[0, 1], [2, 1]]), + ), + blocks=[ + TensorBlock( + values=torch.empty(0, 1, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(0, 1, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.single(), + ), + TensorBlock( + values=torch.empty(0, 5, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(-2, 3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.single(), + ), + ], + ) + + +@pytest.fixture +def layout_cartesian(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.empty(0, 3, 3, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.single(), + ), + ], + ) + + +def test_target_info_scalar(layout_scalar): + target_info = TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) assert target_info.quantity == "energy" assert target_info.unit == "kcal/mol" - assert target_info.per_atom is False assert target_info.gradients == [] + assert not target_info.per_atom - expected = ( - "TargetInfo(quantity='energy', unit='kcal/mol', per_atom=False, gradients=[])" - ) - assert target_info.__repr__() == expected + expected_start = 
"TargetInfo(quantity='energy', unit='kcal/mol'" + assert target_info.__repr__()[: len(expected_start)] == expected_start -def test_target_info_gradients(): +def test_target_info_spherical(layout_spherical): target_info = TargetInfo( - quantity="energy", - unit="kcal/mol", - per_atom=True, - gradients=["positions", "positions"], + quantity="mtt::spherical", unit="kcal/mol", layout=layout_spherical ) - assert target_info.quantity == "energy" + assert target_info.quantity == "mtt::spherical" assert target_info.unit == "kcal/mol" - assert target_info.per_atom is True - assert target_info.gradients == ["positions"] + assert target_info.gradients == [] + assert not target_info.per_atom - expected = ( - "TargetInfo(quantity='energy', unit='kcal/mol', per_atom=True, " - "gradients=['positions'])" - ) - assert target_info.__repr__() == expected + expected_start = "TargetInfo(quantity='mtt::spherical', unit='kcal/mol'" + assert target_info.__repr__()[: len(expected_start)] == expected_start -def test_list_gradients(): - info1 = TargetInfo(quantity="energy", unit="eV") +def test_target_info_cartesian(layout_cartesian): + target_info = TargetInfo( + quantity="mtt::cartesian", unit="kcal/mol", layout=layout_cartesian + ) - info1.gradients = ["positions"] - assert info1.gradients == ["positions"] + assert target_info.quantity == "mtt::cartesian" + assert target_info.unit == "kcal/mol" + assert target_info.gradients == [] + assert not target_info.per_atom - info1.gradients += ["strain"] - assert info1.gradients == ["positions", "strain"] + expected_start = "TargetInfo(quantity='mtt::cartesian', unit='kcal/mol'" + assert target_info.__repr__()[: len(expected_start)] == expected_start -def test_unit_none_conversion(): - info = TargetInfo(quantity="energy", unit=None) +def test_unit_none_conversion(layout_scalar): + info = TargetInfo(quantity="energy", unit=None, layout=layout_scalar) assert info.unit == "" -def test_length_unit_none_conversion(): +def 
test_length_unit_none_conversion(layout_scalar): dataset_info = DatasetInfo( length_unit=None, atomic_types=[1, 2, 3], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")), + targets={ + "energy": TargetInfo( + quantity="energy", unit="kcal/mol", layout=layout_scalar + ) + }, ) assert dataset_info.length_unit == "" -def test_target_info_copy(): - info = TargetInfo(quantity="energy", unit="eV", gradients=["positions"]) - copy = info.copy() - assert copy == info - assert copy is not info - - -def test_target_info_eq(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) +def test_target_info_eq(layout_scalar): + info1 = TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) + info2 = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) assert info1 == info1 assert info1 != info2 -def test_target_info_eq_error(): - info = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) +def test_target_info_eq_error(layout_scalar): + info = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) match = ( "Comparison between a TargetInfo instance and a list instance is not " @@ -106,152 +184,15 @@ def test_target_info_eq_error(): _ = info == [1, 2, 3] -def test_target_info_update(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["strain", "aaa"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["positions"]) - info1.update(info2) - assert info1.gradients == ["aaa", "positions", "strain"] - - -def test_target_info_union(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - info_new = info1.union(info2) - assert isinstance(info_new, TargetInfo) - assert info_new.gradients == ["position", "strain"] - - -def test_target_info_update_non_matching_quantity(): - info1 = TargetInfo(quantity="energy", 
unit="eV") - info2 = TargetInfo(quantity="force", unit="eV") - match = r"Can't update TargetInfo with a different `quantity`: \(energy != force\)" - with pytest.raises(ValueError, match=match): - info1.update(info2) - - -def test_target_info_update_non_matching_unit(): - info1 = TargetInfo(quantity="energy", unit="eV") - info2 = TargetInfo(quantity="energy", unit="kcal") - match = r"Can't update TargetInfo with a different `unit`: \(eV != kcal\)" - with pytest.raises(ValueError, match=match): - info1.update(info2) - - -def test_target_info_update_non_matching_per_atom(): - info1 = TargetInfo(quantity="energy", unit="eV", per_atom=True) - info2 = TargetInfo(quantity="energy", unit="eV", per_atom=False) - match = "Can't update TargetInfo with a different `per_atom` property: " - with pytest.raises(ValueError, match=match): - info1.update(info2) - - -def test_target_info_dict_setitem_new_entry(): - tid = TargetInfoDict() - info = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid["energy"] = info - assert tid["energy"] == info - - -def test_target_info_dict_setitem_update_entry(): - tid = TargetInfoDict() - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - tid["energy"] = info1 - tid["energy"] = info2 - assert tid["energy"].gradients == ["position", "strain"] - - -def test_target_info_dict_setitem_value_error(): - tid = TargetInfoDict() - with pytest.raises(ValueError, match="value to set is not a `TargetInfo` instance"): - tid["energy"] = "not a TargetInfo" - - -def test_target_info_dict_union(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - merged = tid1.union(tid2) - assert merged["energy"] == tid1["energy"] - assert merged["myenergy"] == tid2["myenergy"] - - -def 
test_target_info_dict_merge_error(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - - tid2 = TargetInfoDict() - tid2["energy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients=["strain"] - ) - - match = r"Can't update TargetInfo with a different `unit`: \(eV != kcal/mol\)" - with pytest.raises(ValueError, match=match): - tid1.union(tid2) - - -def test_target_info_dict_intersection(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - intersection = tid1.intersection(tid2) - assert len(intersection) == 1 - assert intersection["myenergy"] == tid1["myenergy"] - - # Test `&` operator - intersection_and = tid1 & tid2 - assert intersection_and == intersection - - -def test_target_info_dict_intersection_error(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients=["strain"] - ) - - match = ( - r"Intersected items with the same key are not the same. 
Intersected " - r"keys are myenergy" - ) - - with pytest.raises(ValueError, match=match): - tid1.intersection(tid2) - - -def test_target_info_dict_difference(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients=["strain"] - ) - - difference = tid1.difference(tid2) - assert len(difference) == 1 - assert difference["energy"] == tid1["energy"] - - difference_sub = tid1 - tid2 - assert difference_sub == difference - - -def test_dataset_info(): +def test_dataset_info(layout_scalar): """Tests the DatasetInfo class.""" - targets = TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")) - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="kcal/mol") + targets = { + "energy": TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) + } + targets["mtt::U0"] = TargetInfo( + quantity="energy", unit="kcal/mol", layout=layout_scalar + ) dataset_info = DatasetInfo( length_unit="angstrom", atomic_types=[3, 1, 2], targets=targets @@ -271,9 +212,13 @@ def test_dataset_info(): assert dataset_info.__repr__() == expected -def test_set_atomic_types(): - targets = TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")) - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="kcal/mol") +def test_set_atomic_types(layout_scalar): + targets = { + "energy": TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) + } + targets["mtt::U0"] = TargetInfo( + quantity="energy", unit="kcal/mol", layout=layout_scalar + ) dataset_info = DatasetInfo( length_unit="angstrom", atomic_types=[3, 1, 2], targets=targets @@ -286,10 +231,12 @@ def test_set_atomic_types(): assert dataset_info.atomic_types == [1, 4, 5, 7] -def test_dataset_info_copy(): - targets = TargetInfoDict() - targets["energy"] = 
TargetInfo(quantity="energy", unit="eV") - targets["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") +def test_dataset_info_copy(layout_scalar, layout_cartesian): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) + targets["mtt::my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_cartesian + ) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) copy = info.copy() @@ -298,31 +245,35 @@ def test_dataset_info_copy(): assert copy is not info -def test_dataset_info_update(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_update(layout_scalar, layout_spherical): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() - targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets2["mtt::my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_spherical + ) info2 = DatasetInfo(length_unit="angstrom", atomic_types=[8], targets=targets2) info.update(info2) assert info.atomic_types == [1, 6, 8] assert info.targets["energy"] == targets["energy"] - assert info.targets["forces"] == targets2["forces"] + assert info.targets["mtt::my-target"] == targets2["mtt::my-target"] -def test_dataset_info_update_non_matching_length_unit(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_update_non_matching_length_unit(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() - targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + 
targets2["mtt::my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_scalar + ) info2 = DatasetInfo(length_unit="nanometer", atomic_types=[8], targets=targets2) @@ -335,23 +286,25 @@ def test_dataset_info_update_non_matching_length_unit(): info.update(info2) -def test_dataset_info_eq(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_eq(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() - targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets2["my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_scalar + ) info2 = DatasetInfo(length_unit="nanometer", atomic_types=[8], targets=targets2) assert info == info assert info != info2 -def test_dataset_info_eq_error(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_eq_error(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) @@ -363,31 +316,39 @@ def test_dataset_info_eq_error(): _ = info == [1, 2, 3] -def test_dataset_info_update_different_target_info(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_update_different_target_info(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) - targets2 = TargetInfoDict() - targets2["energy"] = TargetInfo(quantity="energy", unit="eV/Angstrom") + targets2 = {} + targets2["energy"] = TargetInfo( + quantity="energy", 
unit="eV/Angstrom", layout=layout_scalar + ) info2 = DatasetInfo(length_unit="angstrom", atomic_types=[8], targets=targets2) - match = r"Can't update TargetInfo with a different `unit`: \(eV != eV/Angstrom\)" + match = ( + "Can't update DatasetInfo with different target information for target 'energy'" + ) with pytest.raises(ValueError, match=match): info.update(info2) -def test_dataset_info_union(): +def test_dataset_info_union(layout_scalar, layout_cartesian): """Tests the union method.""" - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") - targets["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) + targets["forces"] = TargetInfo( + quantity="mtt::forces", unit="eV/Angstrom", layout=layout_scalar + ) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) other_targets = targets.copy() - other_targets["mtt::stress"] = TargetInfo(quantity="mtt::stress", unit="GPa") + other_targets["mtt::stress"] = TargetInfo( + quantity="mtt::stress", unit="GPa", layout=layout_cartesian + ) other_info = DatasetInfo( length_unit="angstrom", atomic_types=[1], targets=other_targets @@ -606,7 +567,7 @@ def test_collate_fn(): assert isinstance(batch[1], dict) -def test_get_stats(): +def test_get_stats(layout_scalar): """Tests the get_stats method of Dataset and Subset.""" systems = read_systems(RESOURCES_PATH / "qm9_reduced_100.xyz") @@ -644,8 +605,8 @@ def test_get_stats(): length_unit="angstrom", atomic_types=[1, 6], targets={ - "mtt::U0": TargetInfo(quantity="energy", unit="eV"), - "energy": TargetInfo(quantity="energy", unit="eV"), + "mtt::U0": TargetInfo(quantity="energy", unit="eV", layout=layout_scalar), + "energy": TargetInfo(quantity="energy", unit="eV", layout=layout_scalar), }, ) diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 92b5e656b..5876d8e70 100644 
--- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -11,7 +11,7 @@ from omegaconf import OmegaConf from test_targets_ase import ase_system, ase_systems -from metatrain.utils.data.dataset import TargetInfo, TargetInfoDict +from metatrain.utils.data.dataset import TargetInfo from metatrain.utils.data.readers import ( read_energy, read_forces, @@ -178,7 +178,7 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): assert any(["Forces found" in rec.message for rec in caplog.records]) assert type(result) is dict - assert type(target_info_dict) is TargetInfoDict + assert type(target_info_dict) is dict if stress_dict: assert any(["Stress found" in rec.message for rec in caplog.records]) diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index fd2179e5e..a1b9a7af4 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -8,12 +8,13 @@ from omegaconf import OmegaConf from metatrain.utils.additive import ZBL, CompositionModel, remove_additive -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout RESOURCES_PATH = Path(__file__).parents[1] / "resources" @@ -79,14 +80,12 @@ def test_composition_model_train(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) @@ -207,14 +206,12 @@ def test_composition_model_torchscript(tmpdir): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8], - targets=TargetInfoDict( - { - "energy": 
TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) composition_model = torch.jit.script(composition_model) @@ -335,14 +332,12 @@ def test_composition_model_missing_types(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) with pytest.raises( @@ -356,14 +351,12 @@ def test_composition_model_missing_types(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8, 100], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) with pytest.warns( @@ -387,14 +380,12 @@ def test_composition_model_wrong_target(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="FOO", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="FOO", + layout=energy_layout, + ) + }, ), ) diff --git a/tests/utils/test_evaluate_model.py b/tests/utils/test_evaluate_model.py index e2bd81eca..9c147378d 100644 --- a/tests/utils/test_evaluate_model.py +++ b/tests/utils/test_evaluate_model.py @@ -10,6 +10,7 @@ get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_force_stress_layout from . 
import MODEL_HYPERS, RESOURCES_PATH @@ -27,7 +28,7 @@ def test_evaluate_model(training, exported): targets = { "energy": TargetInfo( - quantity="energy", unit="eV", gradients=["positions", "strain"] + quantity="energy", unit="eV", layout=energy_force_stress_layout ) } diff --git a/tests/utils/test_export.py b/tests/utils/test_export.py index 439a279f4..a0dcc1784 100644 --- a/tests/utils/test_export.py +++ b/tests/utils/test_export.py @@ -7,6 +7,7 @@ from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.export import export, is_exported +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS, RESOURCES_PATH @@ -18,7 +19,9 @@ def test_export(tmp_path): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"energy": TargetInfo(quantity="energy", unit="eV")}, + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -48,7 +51,9 @@ def test_reexport(monkeypatch, tmp_path): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"energy": TargetInfo(quantity="energy", unit="eV")}, + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -79,7 +84,9 @@ def test_length_units_warning(): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"energy": TargetInfo(quantity="energy", unit="eV")}, + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -100,7 +107,7 @@ def test_units_warning(): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"mtt::output": TargetInfo(quantity="energy")}, + targets={"mtt::output": 
TargetInfo(quantity="energy", layout=energy_layout)}, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/utils/test_external_naming.py b/tests/utils/test_external_naming.py index 666d1c890..0e3db1371 100644 --- a/tests/utils/test_external_naming.py +++ b/tests/utils/test_external_naming.py @@ -1,14 +1,15 @@ from metatrain.utils.data.dataset import TargetInfo from metatrain.utils.external_naming import to_external_name, to_internal_name +from metatrain.utils.testing import energy_layout def test_to_external_name(): """Tests the to_external_name function.""" quantities = { - "energy": TargetInfo(quantity="energy"), - "mtt::free_energy": TargetInfo(quantity="energy"), - "mtt::foo": TargetInfo(quantity="bar"), + "energy": TargetInfo(quantity="energy", layout=energy_layout), + "mtt::free_energy": TargetInfo(quantity="energy", layout=energy_layout), + "mtt::foo": TargetInfo(quantity="bar", layout=energy_layout), } assert to_external_name("energy_positions_gradients", quantities) == "forces" diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py index 7bad9279e..94b475f6f 100644 --- a/tests/utils/test_output_gradient.py +++ b/tests/utils/test_output_gradient.py @@ -6,6 +6,11 @@ from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.output_gradient import compute_gradient +from metatrain.utils.testing import ( + energy_force_layout, + energy_force_stress_layout, + energy_stress_layout, +) from . 
import MODEL_HYPERS, RESOURCES_PATH @@ -19,7 +24,7 @@ def test_forces(is_training): atomic_types={1, 6, 7, 8}, targets={ "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=["positions"] + quantity="energy", unit="eV", layout=energy_force_layout ) }, ) @@ -76,7 +81,7 @@ def test_virial(is_training): atomic_types={6}, targets={ "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=["strain"] + quantity="energy", unit="eV", layout=energy_stress_layout ) }, ) @@ -147,8 +152,7 @@ def test_both(is_training): "energy": TargetInfo( quantity="energy", unit="eV", - per_atom=False, - gradients=["positions", "strain"], + layout=energy_force_stress_layout, ) }, ) From 60b6b0d7a156b28b71f7558a9e3045ce74411acf Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 28 Oct 2024 10:34:37 +0100 Subject: [PATCH 02/12] Change `TensorMap`s inside `TargetInfo` to `float64` to avoid serialization problems --- src/metatrain/cli/eval.py | 9 ++++++--- src/metatrain/utils/data/readers/readers.py | 2 +- src/metatrain/utils/testing.py | 9 ++++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 25bcbdb2c..bf17745dd 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -371,7 +371,8 @@ def eval_model( def _get_energy_layout(strain_gradient: bool) -> TensorMap: block = TensorBlock( - values=torch.empty(0, 1), + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 1, dtype=torch.float64), samples=Labels( names=["system"], values=torch.empty((0, 1), dtype=torch.int32), @@ -380,7 +381,8 @@ def _get_energy_layout(strain_gradient: bool) -> TensorMap: properties=Labels.range("energy", 1), ) position_gradient_block = TensorBlock( - values=torch.empty(0, 3, 1), + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 1, dtype=torch.float64), samples=Labels( names=["sample", "atom"], values=torch.empty((0, 2), 
dtype=torch.int32), @@ -397,7 +399,8 @@ def _get_energy_layout(strain_gradient: bool) -> TensorMap: if strain_gradient: strain_gradient_block = TensorBlock( - values=torch.empty(0, 3, 3, 1), + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 3, 1, dtype=torch.float64), samples=Labels( names=["sample", "atom"], values=torch.empty((0, 2), dtype=torch.int32), diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index dc81af7c3..7a6076143 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -306,7 +306,7 @@ def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock: new_block = TensorBlock( values=torch.empty( (0,) + tensor_block.values.shape[1:], - dtype=tensor_block.values.dtype, + dtype=torch.float64, # metatensor can't serialize otherwise device=tensor_block.values.device, ), samples=Labels( diff --git a/src/metatrain/utils/testing.py b/src/metatrain/utils/testing.py index 62dc7382c..faedbdb00 100644 --- a/src/metatrain/utils/testing.py +++ b/src/metatrain/utils/testing.py @@ -6,7 +6,8 @@ block = TensorBlock( - values=torch.empty(0, 1), + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 1, dtype=torch.float64), samples=Labels( names=["system"], values=torch.empty((0, 1), dtype=torch.int32), @@ -21,7 +22,8 @@ block_with_position_gradients = block.copy() position_gradient_block = TensorBlock( - values=torch.empty(0, 3, 1), + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 1, dtype=torch.float64), samples=Labels( names=["sample", "atom"], values=torch.empty((0, 2), dtype=torch.int32), @@ -42,7 +44,8 @@ block_with_position_and_strain_gradients = block_with_position_gradients.copy() strain_gradient_block = TensorBlock( - values=torch.empty(0, 3, 3, 1), + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 3, 1, dtype=torch.float64), samples=Labels( 
names=["sample", "atom"], values=torch.empty((0, 2), dtype=torch.int32), From 4c614f735af2513363e93cf57a9ca47ec45725d7 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 28 Oct 2024 10:45:57 +0100 Subject: [PATCH 03/12] Better documentation --- src/metatrain/utils/data/dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index f9761b129..79eb7b794 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -17,9 +17,11 @@ class TargetInfo: :param quantity: The physical quantity of the target (e.g., "energy"). :param layout: The layout of the target, as a ``TensorMap`` with 0 samples. - This ``TensorMap`` will contain important information such as the names of + This ``TensorMap`` will be used to retrieve the names of the ``samples``, as well as the ``components`` and ``properties`` of the - target and their gradients. + target and their gradients. For example, this allows to infer the type of + the target (scalar, Cartesian tensor, spherical tensor), whether it is per + atom, the names of its gradients, etc. :param unit: The unit of the target. If :py:obj:`None` the ``unit`` will be set to an empty string ``""``. 
""" @@ -37,7 +39,7 @@ def __init__( self._check_layout(layout) - self.quantity = quantity + self.quantity = quantity # float64: otherwise metatensor can't serialize self.layout = layout self.unit = unit if unit is not None else "" From 172af0b6f222e44ab856c28452f4ef6ab85c3c97 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 29 Oct 2024 19:05:39 +0100 Subject: [PATCH 04/12] Upgrade to `metatensor-torch` 0.6.0 --- pyproject.toml | 4 ++-- .../experimental/alchemical_model/tests/test_exported.py | 1 + .../alchemical_model/tests/test_functionality.py | 1 + src/metatrain/experimental/pet/tests/test_exported.py | 1 + src/metatrain/experimental/pet/tests/test_functionality.py | 3 +++ .../experimental/soap_bpnn/tests/test_exported.py | 1 + .../experimental/soap_bpnn/tests/test_functionality.py | 5 +++++ .../experimental/soap_bpnn/tests/test_torchscript.py | 2 ++ src/metatrain/utils/evaluate_model.py | 3 +++ tests/utils/data/test_system_to_ase.py | 1 + tests/utils/data/test_target_writers.py | 1 + tests/utils/test_additive.py | 7 +++++++ tests/utils/test_output_gradient.py | 6 ++++++ tests/utils/test_per_atom.py | 3 +++ tests/utils/test_transfer.py | 2 ++ 15 files changed, 39 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7e3f2d703..d78526820 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,8 @@ authors = [{name = "metatrain developers"}] dependencies = [ "ase < 3.23.0", "metatensor-learn==0.2.3", - "metatensor-operations==0.2.3", - "metatensor-torch==0.5.5", + "metatensor-operations==0.3.0", + "metatensor-torch==0.6.0", "jsonschema", "omegaconf", "python-hostlist", diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 891983693..8b06c2d74 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -33,6 +33,7 @@ def 
test_to(device, dtype): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(exported) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index 7ee3331af..9f998f925 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -27,6 +27,7 @@ def test_prediction_subset_elements(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index f67a15e4c..5ed9ecf6e 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -61,6 +61,7 @@ def test_to(device): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(exported) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index ddf527603..35c44be0b 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -76,6 +76,7 @@ def test_prediction(): types=torch.tensor([6, 6]), 
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) @@ -126,6 +127,7 @@ def test_per_atom_predictions_functionality(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) @@ -177,6 +179,7 @@ def test_selected_atoms_functionality(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index 63242161e..ff926aa60 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -33,6 +33,7 @@ def test_to(device, dtype): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(exported) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 25dd250b6..12bbe4c7a 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -28,6 +28,7 @@ def 
test_prediction_subset_elements(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) model( [system], @@ -56,6 +57,7 @@ def test_prediction_subset_atoms(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]], ), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) energy_monomer = model( @@ -76,6 +78,7 @@ def test_prediction_subset_atoms(): ], ), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) selection_labels = metatensor.torch.Labels( @@ -119,6 +122,7 @@ def test_output_last_layer_features(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]], ), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) # last-layer features per atom: @@ -190,6 +194,7 @@ def test_output_per_atom(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]], ), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) outputs = model( diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 4d6b89898..9e41046a3 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -26,6 +26,7 @@ def test_torchscript(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] ), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) model( [system], @@ -52,6 +53,7 @@ def test_torchscript_with_identity(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] ), cell=torch.zeros(3, 3), + pbcs=torch.tensor([False, False, False]), ) model( [system], diff --git a/src/metatrain/utils/evaluate_model.py b/src/metatrain/utils/evaluate_model.py index 48447b879..5b4500a21 100644 --- a/src/metatrain/utils/evaluate_model.py +++ b/src/metatrain/utils/evaluate_model.py @@ -278,6 +278,7 @@ def 
_prepare_system( positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, + pbc=system.pbc, ) else: if positions_grad: @@ -285,6 +286,7 @@ def _prepare_system( positions=system.positions.detach().clone().requires_grad_(True), cell=system.cell, types=system.types, + pbc=system.pbc, ) strain = None else: @@ -292,6 +294,7 @@ def _prepare_system( positions=system.positions, cell=system.cell, types=system.types, + pbc=system.pbc, ) strain = None diff --git a/tests/utils/data/test_system_to_ase.py b/tests/utils/data/test_system_to_ase.py index e6062b52d..3d66f7e7c 100644 --- a/tests/utils/data/test_system_to_ase.py +++ b/tests/utils/data/test_system_to_ase.py @@ -11,6 +11,7 @@ def test_system_to_ase(): positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), types=torch.tensor([1, 8]), cell=torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]), + pbc=torch.tensor([True, True, True]), ) # Convert the system to an ASE atoms object diff --git a/tests/utils/data/test_target_writers.py b/tests/utils/data/test_target_writers.py index 190d37fbe..525ccd446 100644 --- a/tests/utils/data/test_target_writers.py +++ b/tests/utils/data/test_target_writers.py @@ -20,6 +20,7 @@ def systems_capabilities_predictions( types=torch.tensor([1, 1]), positions=torch.tensor([[0, 0, 0], [0, 0, 0.74]]), cell=cell, + pbc=torch.logical_not(torch.all(cell == 0, dim=1)), ), ] diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index fd2179e5e..3daf3c980 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -33,6 +33,7 @@ def test_composition_model_train(): positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), types=torch.tensor([8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), ), System( positions=torch.tensor( @@ -40,6 +41,7 @@ def test_composition_model_train(): ), types=torch.tensor([1, 1, 8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, 
True, True]), ), System( positions=torch.tensor( @@ -55,6 +57,7 @@ def test_composition_model_train(): ), types=torch.tensor([1, 1, 8, 1, 1, 8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), ), ] energies = [1.0, 5.0, 10.0] @@ -200,6 +203,7 @@ def test_composition_model_torchscript(tmpdir): positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), types=torch.tensor([8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), ) composition_model = CompositionModel( @@ -289,6 +293,7 @@ def test_composition_model_missing_types(): positions=torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float64), types=torch.tensor([8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), ), System( positions=torch.tensor( @@ -296,6 +301,7 @@ def test_composition_model_missing_types(): ), types=torch.tensor([1, 1, 8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), ), System( positions=torch.tensor( @@ -311,6 +317,7 @@ def test_composition_model_missing_types(): ), types=torch.tensor([1, 1, 8, 1, 1, 8]), cell=torch.eye(3, dtype=torch.float64), + pbc=torch.tensor([True, True, True]), ), ] energies = [1.0, 5.0, 10.0] diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py index 7bad9279e..513d43930 100644 --- a/tests/utils/test_output_gradient.py +++ b/tests/utils/test_output_gradient.py @@ -32,6 +32,7 @@ def test_forces(is_training): positions=system.positions.requires_grad_(True), cell=system.cell, types=system.types, + pbc=system.pbc, ) for system in systems ] @@ -50,6 +51,7 @@ def test_forces(is_training): positions=system.positions.requires_grad_(True), cell=system.cell, types=system.types, + pbc=system.pbc, ) for system in systems ] @@ -96,6 +98,7 @@ def test_virial(is_training): positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, + pbc=system.pbc, ) for system, strain in zip(systems, strains) ] 
@@ -121,6 +124,7 @@ def test_virial(is_training): positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, + pbc=system.pbc, ) for system, strain in zip(systems, strains) ] @@ -170,6 +174,7 @@ def test_both(is_training): positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, + pbc=system.pbc, ) for system, strain in zip(systems, strains) ] @@ -193,6 +198,7 @@ def test_both(is_training): positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, + pbc=system.pbc, ) for system, strain in zip(systems, strains) ] diff --git a/tests/utils/test_per_atom.py b/tests/utils/test_per_atom.py index 3f4f34945..0f84180e7 100644 --- a/tests/utils/test_per_atom.py +++ b/tests/utils/test_per_atom.py @@ -13,16 +13,19 @@ def test_average_by_num_atoms(): positions=torch.tensor([[0.0, 0.0, 0.0]]), cell=torch.eye(3), types=torch.tensor([0]), + pbc=torch.tensor([True, True, True]), ), System( positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), cell=torch.eye(3), types=torch.tensor([0, 0]), + pbc=torch.tensor([True, True, True]), ), System( positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), cell=torch.eye(3), types=torch.tensor([0, 0, 0]), + pbc=torch.tensor([True, True, True]), ), ] diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py index b669c0c29..283842779 100644 --- a/tests/utils/test_transfer.py +++ b/tests/utils/test_transfer.py @@ -14,6 +14,7 @@ def test_systems_and_targets_to_dtype(): positions=torch.tensor([[1.0, 1.0, 1.0]]), cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), types=torch.tensor([1]), + pbc=torch.tensor([True, True, True]), ) targets = TensorMap( keys=Labels.single(), @@ -39,6 +40,7 @@ def test_systems_and_targets_to_dtype_and_device(): positions=torch.tensor([[1.0, 1.0, 1.0]]), cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), types=torch.tensor([1]), + pbc=torch.tensor([True, True, 
True]), ) targets = TensorMap( keys=Labels.single(), From 0f3bfd7b19d98ea3a9be62d13d10f6ba55a18e6b Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 30 Oct 2024 15:35:41 +0100 Subject: [PATCH 05/12] Upgrade `metatensor-learn` --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d78526820..6e09fc511 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [{name = "metatrain developers"}] dependencies = [ "ase < 3.23.0", - "metatensor-learn==0.2.3", + "metatensor-learn==0.3.0", "metatensor-operations==0.3.0", "metatensor-torch==0.6.0", "jsonschema", From 0e8e2f71bd0dcf3c1bad534e0ae74270c84c5cb3 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 30 Oct 2024 17:30:22 +0100 Subject: [PATCH 06/12] Update strict NL --- pyproject.toml | 4 ++-- src/metatrain/experimental/alchemical_model/model.py | 1 + .../tests/test_torch_alchemical_compatibility.py | 1 + src/metatrain/experimental/pet/model.py | 1 + .../experimental/pet/tests/test_pet_compatibility.py | 4 ++-- src/metatrain/utils/additive/zbl.py | 1 + src/metatrain/utils/metrics.py | 4 ++-- tests/utils/test_neighbor_list.py | 8 ++++---- 8 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e09fc511..f75f0c018 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] soap-bpnn = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline@d181b28#subdirectory=python/rascaline-torch", + "rascaline-torch @ git+https://github.com/luthaf/rascaline@b70b19e#subdirectory=python/rascaline-torch", ] alchemical-model = [ "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@51ff519", @@ -68,7 +68,7 @@ pet = [ "pet @ git+https://github.com/lab-cosmo/pet@7eddb2e", ] gap = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline@d181b28#subdirectory=python/rascaline-torch", + 
"rascaline-torch @ git+https://github.com/luthaf/rascaline@b70b19e#subdirectory=python/rascaline-torch", "skmatter", "metatensor-learn", "scipy", diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py index 8ffbd8dc7..4e985dc75 100644 --- a/src/metatrain/experimental/alchemical_model/model.py +++ b/src/metatrain/experimental/alchemical_model/model.py @@ -78,6 +78,7 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: NeighborListOptions( cutoff=self.cutoff, full_list=True, + strict=True, ) ] diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py index 03b7ef1df..7b560e786 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py @@ -29,6 +29,7 @@ nl_options = NeighborListOptions( cutoff=5.0, full_list=True, + strict=True, ) systems = [get_system_with_neighbor_lists(system, [nl_options]) for system in systems] diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index bf567dee1..512dc319b 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -73,6 +73,7 @@ def requested_neighbor_lists( NeighborListOptions( cutoff=self.cutoff, full_list=True, + strict=True, ) ] diff --git a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py index 04ecae1bc..4a2bdd71e 100644 --- a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py @@ -59,7 +59,7 @@ def test_batch_dicts_compatibility(cutoff): structure = ase.io.read(DATASET_PATH) atomic_types = sorted(set(structure.numbers)) system = 
systems_to_torch(structure) - options = NeighborListOptions(cutoff=cutoff, full_list=True) + options = NeighborListOptions(cutoff=cutoff, full_list=True, strict=True) system = get_system_with_neighbor_lists(system, [options]) ARCHITECTURAL_HYPERS = Hypers(DEFAULT_HYPERS["model"]) @@ -121,7 +121,7 @@ def test_predictions_compatibility(cutoff): model.set_trained_model(raw_pet) system = systems_to_torch(structure) - options = NeighborListOptions(cutoff=cutoff, full_list=True) + options = NeighborListOptions(cutoff=cutoff, full_list=True, strict=True) system = get_system_with_neighbor_lists(system, [options]) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py index edf2ec7b2..cec2ee96e 100644 --- a/src/metatrain/utils/additive/zbl.py +++ b/src/metatrain/utils/additive/zbl.py @@ -251,6 +251,7 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: NeighborListOptions( cutoff=self.cutoff_radius, full_list=True, + strict=True, ) ] diff --git a/src/metatrain/utils/metrics.py b/src/metatrain/utils/metrics.py index 357b567d0..29ace5104 100644 --- a/src/metatrain/utils/metrics.py +++ b/src/metatrain/utils/metrics.py @@ -8,7 +8,7 @@ class RMSEAccumulator: """Accumulates the RMSE between predictions and targets for an arbitrary number of keys, each corresponding to one target.""" - def __init__(self): + def __init__(self) -> None: """Initialize the accumulator.""" self.information: Dict[str, Tuple[float, int]] = {} @@ -91,7 +91,7 @@ class MAEAccumulator: """Accumulates the MAE between predictions and targets for an arbitrary number of keys, each corresponding to one target.""" - def __init__(self): + def __init__(self) -> None: """Initialize the accumulator.""" self.information: Dict[str, Tuple[float, int]] = {} diff --git a/tests/utils/test_neighbor_list.py b/tests/utils/test_neighbor_list.py index 9c35e05d5..768114d6c 100644 --- a/tests/utils/test_neighbor_list.py +++ 
b/tests/utils/test_neighbor_list.py @@ -14,9 +14,9 @@ def test_attach_neighbor_lists(): systems = read_systems_ase(filename) requested_neighbor_lists = [ - NeighborListOptions(cutoff=4.0, full_list=True), - NeighborListOptions(cutoff=5.0, full_list=False), - NeighborListOptions(cutoff=6.0, full_list=True), + NeighborListOptions(cutoff=4.0, full_list=True, strict=True), + NeighborListOptions(cutoff=5.0, full_list=False, strict=True), + NeighborListOptions(cutoff=6.0, full_list=True, strict=True), ] new_system = get_system_with_neighbor_lists(systems[0], requested_neighbor_lists) @@ -25,7 +25,7 @@ def test_attach_neighbor_lists(): assert requested_neighbor_lists[1] in new_system.known_neighbor_lists() assert requested_neighbor_lists[2] in new_system.known_neighbor_lists() - extraneous_nl = NeighborListOptions(cutoff=5.0, full_list=True) + extraneous_nl = NeighborListOptions(cutoff=5.0, full_list=True, strict=True) assert extraneous_nl not in new_system.known_neighbor_lists() for nl_options in new_system.known_neighbor_lists(): From d632dc2f68a6b92b96e55874de2b2cbe7a8892d4 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 30 Oct 2024 17:39:08 +0100 Subject: [PATCH 07/12] Fix PBCs --- .../experimental/pet/tests/test_functionality.py | 6 +++--- .../experimental/soap_bpnn/tests/test_exported.py | 2 +- .../experimental/soap_bpnn/tests/test_functionality.py | 10 +++++----- .../experimental/soap_bpnn/tests/test_torchscript.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 35c44be0b..b0f5771ea 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -76,7 +76,7 @@ def test_prediction(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, 
False]), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) @@ -127,7 +127,7 @@ def test_per_atom_predictions_functionality(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) @@ -179,7 +179,7 @@ def test_selected_atoms_functionality(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(model) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index ff926aa60..349e89cfd 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -33,7 +33,7 @@ def test_to(device, dtype): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) requested_neighbor_lists = get_requested_neighbor_lists(exported) system = get_system_with_neighbor_lists(system, requested_neighbor_lists) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 12bbe4c7a..b1fb4f1d3 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ 
b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -28,7 +28,7 @@ def test_prediction_subset_elements(): types=torch.tensor([6, 6]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) model( [system], @@ -57,7 +57,7 @@ def test_prediction_subset_atoms(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0]], ), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) energy_monomer = model( @@ -78,7 +78,7 @@ def test_prediction_subset_atoms(): ], ), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) selection_labels = metatensor.torch.Labels( @@ -122,7 +122,7 @@ def test_output_last_layer_features(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]], ), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) # last-layer features per atom: @@ -194,7 +194,7 @@ def test_output_per_atom(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]], ), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) outputs = model( diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 9e41046a3..6be11f582 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -26,7 +26,7 @@ def test_torchscript(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] ), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) model( [system], @@ -53,7 +53,7 @@ def test_torchscript_with_identity(): [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], 
[0.0, 0.0, 3.0]] ), cell=torch.zeros(3, 3), - pbcs=torch.tensor([False, False, False]), + pbc=torch.tensor([False, False, False]), ) model( [system], From cf62ee7b32f7ca5ee2f8cf1d5f49bf77386ec3f5 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 1 Nov 2024 12:33:06 +0100 Subject: [PATCH 08/12] Upgrade rascaline-torch --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f75f0c018..bec22e930 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] soap-bpnn = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline@b70b19e#subdirectory=python/rascaline-torch", + "rascaline-torch @ git+https://github.com/luthaf/rascaline@5326b6e#subdirectory=python/rascaline-torch", ] alchemical-model = [ "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@51ff519", @@ -68,7 +68,7 @@ pet = [ "pet @ git+https://github.com/lab-cosmo/pet@7eddb2e", ] gap = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline@b70b19e#subdirectory=python/rascaline-torch", + "rascaline-torch @ git+https://github.com/luthaf/rascaline@5326b6e#subdirectory=python/rascaline-torch", "skmatter", "metatensor-learn", "scipy", From 0ebcc9aba44ff12aa06b90b35725e7f8ae018796 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 2 Nov 2024 01:34:31 +0100 Subject: [PATCH 09/12] Fix `slice` argument name --- src/metatrain/experimental/pet/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py index 512dc319b..d0d21d73f 100644 --- a/src/metatrain/experimental/pet/model.py +++ b/src/metatrain/experimental/pet/model.py @@ -112,9 +112,7 @@ def forward( values=predictions, ) if selected_atoms is not None: - block = metatensor.torch.slice_block( - block, axis="samples", labels=selected_atoms - ) + block = 
metatensor.torch.slice_block(block, "samples", selected_atoms) output_tmap = TensorMap(keys=empty_labels, blocks=[block]) if not outputs[output_name].per_atom: output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom") From f881951b95d54e29b8c9aa00b188affb8608339e Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 5 Nov 2024 14:44:12 +0100 Subject: [PATCH 10/12] Add dataset information overview to the dev docs --- docs/src/dev-docs/dataset-information.rst | 38 +++++++++++++++++++++++ docs/src/dev-docs/index.rst | 1 + docs/src/dev-docs/utils/data/index.rst | 2 ++ 3 files changed, 41 insertions(+) create mode 100644 docs/src/dev-docs/dataset-information.rst diff --git a/docs/src/dev-docs/dataset-information.rst b/docs/src/dev-docs/dataset-information.rst new file mode 100644 index 000000000..1db07402e --- /dev/null +++ b/docs/src/dev-docs/dataset-information.rst @@ -0,0 +1,38 @@ +Dataset Information +=================== + +When working with ``metatrain``, you will most likely need to interact with some core +classes which are responsible for storing some information about datasets. All these +classes belong to the ``metatrain.utils.data`` module which can be found in the +:ref:`data` section of the developer documentation. + +These classes are: + +- :py:class:`metatrain.utils.data.DatasetInfo`: This class is responsible for storing + information about a dataset. It contains the length unit used in the dataset, the + atomic types present, as well as information about the dataset's targets as a + ``Dict[str, TargetInfo]`` object. The keys of this dictionary are the names of the + targets in the datasets (e.g., ``energy``, ``mtt::dipole``, etc.). + +- :py:class:`metatrain.utils.data.TargetInfo`: This class is responsible for storing + information about a target in a dataset. It contains the target's physical quantity, + the unit in which the target is expressed, and the ``layout`` of the target. 
The + ``layout`` is a ``TensorMap`` object with zero samples which is used to exemplify + the metadata of each target. + +At the moment, only three types of layouts are supported: + +- scalar: This type of layout is used when the target is a scalar quantity. The + ``layout`` ``TensorMap`` object corresponding to a scalar must have one + ``TensorBlock`` and no ``components``. +- Cartesian tensor: This type of layout is used when the target is a Cartesian tensor. + The ``layout`` ``TensorMap`` object corresponding to a Cartesian tensor must have + one ``TensorBlock`` and as many ``components`` as the tensor's rank. These + components are named ``xyz`` for a tensor of rank 1 and ``xyz_1``, ``xyz_2``, and + so on for higher ranks. +- Spherical tensor: This type of layout is used when the target is a spherical tensor. + The ``layout`` ``TensorMap`` object corresponding to a spherical tensor can have + multiple blocks corresponding to different irreps (irreducible representations) of + the target. The ``keys`` of the ``TensorMap`` object must have the ``o3_lambda`` + and ``o3_sigma`` names, and each ``TensorBlock`` must have a single component named + ``o3_mu``. diff --git a/docs/src/dev-docs/index.rst b/docs/src/dev-docs/index.rst index 9dd337c6d..8dd3d91da 100644 --- a/docs/src/dev-docs/index.rst +++ b/docs/src/dev-docs/index.rst @@ -12,5 +12,6 @@ module. getting-started architecture-life-cycle new-architecture + dataset-information cli/index utils/index diff --git a/docs/src/dev-docs/utils/data/index.rst b/docs/src/dev-docs/utils/data/index.rst index 5f7f80970..a3c3c44c3 100644 --- a/docs/src/dev-docs/utils/data/index.rst +++ b/docs/src/dev-docs/utils/data/index.rst @@ -1,6 +1,8 @@ Data ==== +.. _data: + API for handling data in ``metatrain``. ..
toctree:: From de9a28e480826a6bfc366f4317c42f855a81ebb0 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 5 Nov 2024 16:51:10 +0100 Subject: [PATCH 11/12] Fix docs --- docs/src/dev-docs/utils/data/index.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/dev-docs/utils/data/index.rst b/docs/src/dev-docs/utils/data/index.rst index a3c3c44c3..4a4034be1 100644 --- a/docs/src/dev-docs/utils/data/index.rst +++ b/docs/src/dev-docs/utils/data/index.rst @@ -1,8 +1,8 @@ +.. _data: + Data ==== -.. _data: - API for handling data in ``metatrain``. .. toctree:: From ce0ee06f44e8067e33bdcac7e7cc2460758e47fc Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 5 Nov 2024 17:22:18 +0100 Subject: [PATCH 12/12] Review changes --- src/metatrain/cli/train.py | 3 ++- src/metatrain/utils/data/dataset.py | 35 ++++++++++++++++++++--------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 2d93cf7e1..e053f3810 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -236,7 +236,8 @@ def train_model( for key in intersecting_keys: if target_info_dict[key] != target_info_dict_single[key]: raise ValueError( - f"Target information for key {key} differs between training sets." + f"Target information for key {key} differs between training sets. " + f"Got {target_info_dict[key]} and {target_info_dict_single[key]}." 
) target_info_dict.update(target_info_dict_single) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 79eb7b794..20c192503 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -33,9 +33,9 @@ def __init__( unit: Union[None, str] = "", ): # one of these will be set to True inside the _check_layout method - self.is_scalar = False - self.is_cartesian = False - self.is_spherical = False + self._is_scalar = False + self._is_cartesian = False + self._is_spherical = False self._check_layout(layout) @@ -43,10 +43,25 @@ def __init__( self.layout = layout self.unit = unit if unit is not None else "" + @property + def is_scalar(self) -> bool: + """Whether the target is a scalar.""" + return self._is_scalar + + @property + def is_cartesian(self) -> bool: + """Whether the target is a Cartesian tensor.""" + return self._is_cartesian + + @property + def is_spherical(self) -> bool: + """Whether the target is a spherical tensor.""" + return self._is_spherical + @property def gradients(self) -> List[str]: """Sorted and unique list of gradient names.""" - if self.is_scalar: + if self._is_scalar: return sorted(self.layout.block().gradients_list()) else: return [] @@ -102,14 +117,14 @@ def _check_layout(self, layout: TensorMap) -> None: ) components_first_block = layout.block(0).components if len(components_first_block) == 0: - self.is_scalar = True + self._is_scalar = True elif components_first_block[0].names[0].startswith("xyz"): - self.is_cartesian = True + self._is_cartesian = True elif ( len(components_first_block) == 1 and components_first_block[0].names[0] == "o3_mu" ): - self.is_spherical = True + self._is_spherical = True else: raise ValueError( "The layout ``TensorMap`` of a target should be " @@ -117,7 +132,7 @@ def _check_layout(self, layout: TensorMap) -> None: "the target could not be determined." 
) - if self.is_scalar: + if self._is_scalar: if layout.keys.names != ["_"]: raise ValueError( "The layout ``TensorMap`` of a scalar target should have " @@ -136,7 +151,7 @@ def _check_layout(self, layout: TensorMap) -> None: "scalar targets. " f"Found '{gradient_name}' instead." ) - if self.is_cartesian: + if self._is_cartesian: if layout.keys.names != ["_"]: raise ValueError( "The layout ``TensorMap`` of a Cartesian tensor target should have " @@ -152,7 +167,7 @@ def _check_layout(self, layout: TensorMap) -> None: "Gradients of Cartesian tensor targets are not supported." ) - if self.is_spherical: + if self._is_spherical: if layout.keys.names != ["o3_lambda", "o3_sigma"]: raise ValueError( "The layout ``TensorMap`` of a spherical tensor target "