diff --git a/docs/src/dev-docs/dataset-information.rst b/docs/src/dev-docs/dataset-information.rst new file mode 100644 index 000000000..1db07402e --- /dev/null +++ b/docs/src/dev-docs/dataset-information.rst @@ -0,0 +1,38 @@ +Dataset Information +=================== + +When working with ``metatrain``, you will most likely need to interact with some core +classes which are responsible for storing some information about datasets. All these +classes belong to the ``metatrain.utils.data`` module which can be found in the +:ref:`data` section of the developer documentation. + +These classes are: + +- :py:class:`metatrain.utils.data.DatasetInfo`: This class is responsible for storing + information about a dataset. It contains the length unit used in the dataset, the + atomic types present, as well as information about the dataset's targets as a + ``Dict[str, TargetInfo]`` object. The keys of this dictionary are the names of the + targets in the datasets (e.g., ``energy``, ``mtt::dipole``, etc.). + +- :py:class:`metatrain.utils.data.TargetInfo`: This class is responsible for storing + information about a target in a dataset. It contains the target's physical quantity, + the unit in which the target is expressed, and the ``layout`` of the target. The + ``layout`` is ``TensorMap`` object with zero samples which is used to exemplify + the metadata of each target. + +At the moment, only three types of layouts are supported: + +- scalar: This type of layout is used when the target is a scalar quantity. The + ``layout`` ``TensorMap`` object corresponding to a scalar must have one + ``TensorBlock`` and no ``components``. +- Cartesian tensor: This type of layout is used when the target is a Cartesian tensor. + The ``layout`` ``TensorMap`` object corresponding to a Cartesian tensor must have + one ``TensorBlock`` and as many ``components`` as the tensor's rank. These + components are named ``xyz`` for a tensor of rank 1 and ``xyz_1``, ``xyz_2``, and + so on for higher ranks. +- Spherical tensor: This type of layout is used when the target is a spherical tensor. + The ``layout`` ``TensorMap`` object corresponding to a spherical tensor can have + multiple blocks corresponding to different irreps (irreducible representations) of + the target. The ``keys`` of the ``TensorMap`` object must have the ``o3_lambda`` + and ``o3_sigma`` names, and each ``TensorBlock`` must have a single component named + ``o3_mu``. diff --git a/docs/src/dev-docs/index.rst b/docs/src/dev-docs/index.rst index 9dd337c6d..8dd3d91da 100644 --- a/docs/src/dev-docs/index.rst +++ b/docs/src/dev-docs/index.rst @@ -12,5 +12,6 @@ module. getting-started architecture-life-cycle new-architecture + dataset-information cli/index utils/index diff --git a/docs/src/dev-docs/utils/data/index.rst b/docs/src/dev-docs/utils/data/index.rst index 5f7f80970..4a4034be1 100644 --- a/docs/src/dev-docs/utils/data/index.rst +++ b/docs/src/dev-docs/utils/data/index.rst @@ -1,3 +1,5 @@ +.. _data: + Data ==== diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 1a9de5a42..bf17745dd 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -15,7 +15,6 @@ from ..utils.data import ( Dataset, TargetInfo, - TargetInfoDict, collate_fn, read_systems, read_targets, @@ -159,7 +158,7 @@ def _concatenate_tensormaps( def _eval_targets( model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule], dataset: Union[Dataset, torch.utils.data.Subset], - options: TargetInfoDict, + options: Dict[str, TargetInfo], return_predictions: bool, check_consistency: bool = False, ) -> Optional[Dict[str, TensorMap]]: @@ -335,17 +334,17 @@ def eval_model( # (but we don't/can't calculate RMSEs) # TODO: allow the user to specify which outputs to evaluate eval_targets = {} - eval_info_dict = TargetInfoDict() - gradients = ["positions"] - if all(not torch.all(system.cell == 0) for system in eval_systems): - # only add strain if all structures have cells - gradients.append("strain") + eval_info_dict = {} + do_strain_grad = all( + not torch.all(system.cell == 0) for system in eval_systems + ) + layout = _get_energy_layout(do_strain_grad) # TODO: layout from the user for key in model.capabilities().outputs.keys(): eval_info_dict[key] = TargetInfo( quantity=model.capabilities().outputs[key].quantity, unit=model.capabilities().outputs[key].unit, - per_atom=False, # TODO: allow the user to specify this - gradients=gradients, + # TODO: allow the user to specify whether per-atom or not + layout=layout, ) eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets}) @@ -368,3 +367,60 @@ def eval_model( capabilities=model.capabilities(), predictions=predictions, ) + + +def _get_energy_layout(strain_gradient: bool) -> TensorMap: + block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), + ) + position_gradient_block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 1, dtype=torch.float64), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), + ) + block.add_gradient("positions", position_gradient_block) + + if strain_gradient: + strain_gradient_block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 3, 1, dtype=torch.float64), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), + ) + block.add_gradient("strain", strain_gradient_block) + + energy_layout = TensorMap( + keys=Labels.single(), + blocks=[block], + ) + return energy_layout diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 407c798f4..e053f3810 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -22,7 +22,7 @@ ) from ..utils.data import ( DatasetInfo, - TargetInfoDict, + TargetInfo, get_atomic_types, get_dataset, get_stats, @@ -228,11 +228,18 @@ def train_model( options["training_set"] = expand_dataset_config(options["training_set"]) train_datasets = [] - target_infos = TargetInfoDict() - for train_options in options["training_set"]: - dataset, target_info_dict = get_dataset(train_options) + target_info_dict: Dict[str, TargetInfo] = {} + for train_options in options["training_set"]: # loop over training sets + dataset, target_info_dict_single = get_dataset(train_options) train_datasets.append(dataset) - target_infos.update(target_info_dict) + intersecting_keys = target_info_dict.keys() & target_info_dict_single.keys() + for key in intersecting_keys: + if target_info_dict[key] != target_info_dict_single[key]: + raise ValueError( + f"Target information for key {key} differs between training sets. " + f"Got {target_info_dict[key]} and {target_info_dict_single[key]}." + ) + target_info_dict.update(target_info_dict_single) train_size = 1.0 @@ -321,7 +328,7 @@ def train_model( dataset_info = DatasetInfo( length_unit=options["training_set"][0]["systems"]["length_unit"], atomic_types=atomic_types, - targets=target_infos, + targets=target_info_dict, ) ########################### diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 8b06c2d74..04874e6c3 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -3,11 +3,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -22,7 +23,9 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype) exported = model.export() diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index 9f998f925..6b2eae8df 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -2,11 +2,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -18,7 +19,9 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index 9d9a84dd9..c8a42676d 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -5,11 +5,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, MODEL_HYPERS @@ -20,7 +21,9 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index 648c91c18..0d82379c3 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -13,11 +13,11 @@ read_systems, read_targets, ) -from metatrain.utils.data.dataset import TargetInfoDict from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -31,8 +31,8 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py index 7b560e786..9f17dd9d0 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py @@ -14,8 +14,9 @@ from metatrain.experimental.alchemical_model.utils import ( systems_to_torch_alchemical_batch, ) -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems +from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS, QM9_DATASET_PATH @@ -71,7 +72,9 @@ def test_alchemical_model_inference(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=unique_numbers, - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py index 33e0b9e9f..30d5511d7 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py @@ -1,7 +1,8 @@ import torch from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -12,7 +13,9 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) @@ -25,7 +28,9 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) torch.jit.save( diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 3ed190c00..4127e7036 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -9,7 +9,6 @@ from ...utils.data import ( CombinedDataLoader, Dataset, - TargetInfoDict, check_datasets, collate_fn, get_all_targets, @@ -247,12 +246,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{ - key: model.dataset_info.targets[key] - for key in targets.keys() - } - ), + {key: model.dataset_info.targets[key] for key in targets.keys()}, is_training=True, ) @@ -295,12 +289,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{ - key: model.dataset_info.targets[key] - for key in targets.keys() - } - ), + {key: model.dataset_info.targets[key] for key in targets.keys()}, is_training=False, ) diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index b0359bd80..791e061c6 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -12,10 +12,10 @@ Dataset, DatasetInfo, TargetInfo, - TargetInfoDict, read_systems, read_targets, ) +from metatrain.utils.testing import energy_force_layout from . import DATASET_ETHANOL_PATH, DEFAULT_HYPERS @@ -61,9 +61,11 @@ def test_ethanol_regression_train_and_invariance(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 30 - target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) - ) + target_info_dict = { + "energy": TargetInfo( + quantity="energy", unit="kcal/mol", layout=energy_force_layout + ) + } dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index 81212353c..a3d703e76 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -8,8 +8,9 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_force_layout, energy_layout from . import DATASET_ETHANOL_PATH, DATASET_PATH, DEFAULT_HYPERS @@ -25,8 +26,8 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets @@ -57,8 +58,10 @@ def test_regression_train_and_invariance(): targets, _ = read_targets(OmegaConf.create(conf)) dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) - target_info_dict = TargetInfoDict() - target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + target_info_dict = {} + target_info_dict["mtt::U0"] = TargetInfo( + quantity="energy", unit="eV", layout=energy_layout + ) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -138,9 +141,11 @@ def test_ethanol_regression_train_and_invariance(): hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 900 - target_info_dict = TargetInfoDict( - energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"]) - ) + target_info_dict = { + "energy": TargetInfo( + quantity="energy", unit="kcal/mol", layout=energy_force_layout + ) + } dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index 967a83353..e7d1cbc2f 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ -2,8 +2,9 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, DEFAULT_HYPERS @@ -13,8 +14,10 @@ def test_torchscript(): """Tests that the model can be jitted.""" - target_info_dict = TargetInfoDict() - target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + target_info_dict = {} + target_info_dict["mtt::U0"] = TargetInfo( + quantity="energy", unit="eV", layout=energy_layout + ) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -34,8 +37,6 @@ def test_torchscript(): targets, _ = read_targets(OmegaConf.create(conf)) systems = read_systems(DATASET_PATH) - # for system in systems: - # system.types = torch.ones(len(system.types), dtype=torch.int32) dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() @@ -64,8 +65,8 @@ def test_torchscript(): def test_torchscript_save(): """Tests that the model can be jitted and saved.""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index 5ed9ecf6e..e2cc2e0f4 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -11,12 +11,13 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.export import export from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -32,7 +33,9 @@ def test_to(device): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index b0f5771ea..2ebd3841f 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -18,12 +18,13 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.jsonschema import validate from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -65,7 +66,9 @@ def test_prediction(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -116,7 +119,9 @@ def test_per_atom_predictions_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -168,7 +173,9 @@ def test_selected_atoms_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py index 4a2bdd71e..1a83371af 100644 --- a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py @@ -17,8 +17,9 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.experimental.pet.utils import systems_to_batch_dict from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.testing import energy_layout from . import DATASET_PATH @@ -97,7 +98,9 @@ def test_predictions_compatibility(cutoff): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=structure.numbers, - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) capabilities = ModelCapabilities( length_unit="Angstrom", diff --git a/src/metatrain/experimental/pet/tests/test_torchscript.py b/src/metatrain/experimental/pet/tests/test_torchscript.py index df0584cd3..15adc95a8 100644 --- a/src/metatrain/experimental/pet/tests/test_torchscript.py +++ b/src/metatrain/experimental/pet/tests/test_torchscript.py @@ -4,7 +4,8 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -16,7 +17,9 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -31,7 +34,9 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 556f3ef52..2684e34db 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -204,7 +204,11 @@ def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn": new_atomic_types = [ at for at in merged_info.atomic_types if at not in self.atomic_types ] - new_targets = merged_info.targets - self.dataset_info.targets + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } if len(new_atomic_types) > 0: raise ValueError( diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index fc732897e..50fa7118a 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -6,8 +6,9 @@ from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -22,8 +23,10 @@ def test_continue(monkeypatch, tmp_path): systems = read_systems(DATASET_PATH) systems = [system.to(torch.float32) for system in systems] - target_info_dict = TargetInfoDict() - target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + target_info_dict = {} + target_info_dict["mtt::U0"] = TargetInfo( + quantity="energy", unit="eV", layout=energy_layout + ) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index 349e89cfd..421cc422a 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -3,11 +3,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -22,7 +23,9 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info).to(dtype=dtype) exported = model.export() diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index b1fb4f1d3..2e7cde5e7 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -7,7 +7,8 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.architectures import check_architecture_options -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import DEFAULT_HYPERS, MODEL_HYPERS @@ -19,7 +20,9 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -43,7 +46,9 @@ def test_prediction_subset_atoms(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -111,7 +116,9 @@ def test_output_last_layer_features(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -183,7 +190,9 @@ def test_output_per_atom(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py index 2b5835b74..92c96767d 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py @@ -5,7 +5,8 @@ from metatensor.torch.atomistic import systems_to_torch from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, MODEL_HYPERS @@ -16,7 +17,9 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 7b4161ddb..35f691051 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -6,8 +6,9 @@ from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.testing import energy_layout from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -21,8 +22,8 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" - targets = TargetInfoDict() - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") + targets = {} + targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 6be11f582..53a7e161e 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -4,7 +4,8 @@ from metatensor.torch.atomistic import System from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -15,7 +16,9 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) model = torch.jit.script(model) @@ -40,7 +43,9 @@ def test_torchscript_with_identity(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) hypers = copy.deepcopy(MODEL_HYPERS) hypers["bpnn"]["layernorm"] = False @@ -67,7 +72,9 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")), + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) torch.jit.save( diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index aed858bcd..810d9e471 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -8,8 +8,7 @@ from torch.utils.data import DataLoader, DistributedSampler from ...utils.additive import remove_additive -from ...utils.data import CombinedDataLoader, Dataset, TargetInfoDict, collate_fn -from ...utils.data.extract_targets import get_targets_dict +from ...utils.data import CombinedDataLoader, Dataset, collate_fn from ...utils.distributed.distributed_data_parallel import DistributedDataParallel from ...utils.distributed.slurm import DistributedEnvironment from ...utils.evaluate_model import evaluate_model @@ -182,9 +181,7 @@ def train( val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) # Extract all the possible outputs and their gradients: - train_targets = get_targets_dict( - train_datasets, (model.module if is_distributed else model).dataset_info - ) + train_targets = (model.module if is_distributed else model).dataset_info.targets outputs_list = [] for target_name, target_info in train_targets.items(): outputs_list.append(target_name) @@ -270,9 +267,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{key: train_targets[key] for key in targets.keys()} - ), + {key: train_targets[key] for key in targets.keys()}, is_training=True, ) @@ -325,9 +320,7 @@ def train( predictions = evaluate_model( model, systems, - TargetInfoDict( - **{key: train_targets[key] for key in targets.keys()} - ), + {key: train_targets[key] for key in targets.keys()}, is_training=False, ) diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py index 4235af1fc..899a96a07 100644 --- a/src/metatrain/utils/additive/remove.py +++ b/src/metatrain/utils/additive/remove.py @@ -1,12 +1,12 @@ import warnings -from typing import Dict, List, Union +from typing import Dict, List import metatensor.torch import torch from metatensor.torch import TensorMap from metatensor.torch.atomistic import System -from ..data import TargetInfo, TargetInfoDict +from ..data import TargetInfo from ..evaluate_model import evaluate_model @@ -14,7 +14,7 @@ def remove_additive( systems: List[System], targets: Dict[str, TensorMap], additive_model: torch.nn.Module, - target_info_dict: Union[Dict[str, TargetInfo], TargetInfoDict], + target_info_dict: Dict[str, TargetInfo], ): """Remove an additive contribution from the training targets. @@ -35,7 +35,7 @@ def remove_additive( additive_contribution = evaluate_model( additive_model, systems, - TargetInfoDict(**{key: target_info_dict[key] for key in targets.keys()}), + {key: target_info_dict[key] for key in targets.keys()}, is_training=False, # we don't need any gradients w.r.t. any parameters ) diff --git a/src/metatrain/utils/data/__init__.py b/src/metatrain/utils/data/__init__.py index 0aa128433..e93621947 100644 --- a/src/metatrain/utils/data/__init__.py +++ b/src/metatrain/utils/data/__init__.py @@ -1,7 +1,6 @@ from .dataset import ( # noqa: F401 Dataset, TargetInfo, - TargetInfoDict, DatasetInfo, get_atomic_types, get_all_targets, @@ -21,5 +20,4 @@ from .writers import write_predictions # noqa: F401 from .combine_dataloaders import CombinedDataLoader # noqa: F401 from .system_to_ase import system_to_ase # noqa: F401 -from .extract_targets import get_targets_dict # noqa: F401 from .get_dataset import get_dataset # noqa: F401 diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index cd6a68ab2..20c192503 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -1,8 +1,8 @@ import math import warnings -from collections import UserDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union +import metatensor.torch import numpy as np from metatensor.learn.data import Dataset, group_and_join from metatensor.torch import TensorMap @@ -15,40 +15,66 @@ class TargetInfo: """A class that contains information about a target. - :param quantity: The quantity of the target. + :param quantity: The physical quantity of the target (e.g., "energy"). + :param layout: The layout of the target, as a ``TensorMap`` with 0 samples. + This ``TensorMap`` will be used to retrieve the names of + the ``samples``, as well as the ``components`` and ``properties`` of the + target and their gradients. For example, this allows to infer the type of + the target (scalar, Cartesian tensor, spherical tensor), whether it is per + atom, the names of its gradients, etc. :param unit: The unit of the target. If :py:obj:`None` the ``unit`` will be set to an empty string ``""``. - :param per_atom: Whether the target is a per-atom quantity. - :param gradients: List containing the gradient names of the target that are present - in the target. Examples are ``"positions"`` or ``"strain"``. ``gradients`` will - be stored as a sorted list of **unique** gradients. """ def __init__( self, quantity: str, + layout: TensorMap, unit: Union[None, str] = "", - per_atom: bool = False, - gradients: Optional[List[str]] = None, ): - self.quantity = quantity + # one of these will be set to True inside the _check_layout method + self._is_scalar = False + self._is_cartesian = False + self._is_spherical = False + + self._check_layout(layout) + + self.quantity = quantity # float64: otherwise metatensor can't serialize + self.layout = layout self.unit = unit if unit is not None else "" - self.per_atom = per_atom - self._gradients = set(gradients) if gradients is not None else set() + + @property + def is_scalar(self) -> bool: + """Whether the target is a scalar.""" + return self._is_scalar + + @property + def is_cartesian(self) -> bool: + """Whether the target is a Cartesian tensor.""" + return self._is_cartesian + + @property + def is_spherical(self) -> bool: + """Whether the target is a spherical tensor.""" + return self._is_spherical @property def gradients(self) -> List[str]: """Sorted and unique list of gradient names.""" - return sorted(self._gradients) + if self._is_scalar: + return sorted(self.layout.block().gradients_list()) + else: + return [] - @gradients.setter - def gradients(self, value: List[str]): - self._gradients = set(value) + @property + def per_atom(self) -> bool: + """Whether the target is per atom.""" + return "atom" in self.layout.block(0).samples.names def __repr__(self): return ( f"TargetInfo(quantity={self.quantity!r}, unit={self.unit!r}, " - f"per_atom={self.per_atom!r}, gradients={self.gradients!r})" + f"layout={self.layout!r})" ) def __eq__(self, other): @@ -60,109 +86,126 @@ def __eq__(self, other): return ( self.quantity == other.quantity and self.unit == other.unit - and self.per_atom == other.per_atom - and self._gradients == other._gradients + and metatensor.torch.equal(self.layout, other.layout) ) - def copy(self) -> "TargetInfo": - """Return a shallow copy of the TargetInfo.""" - return TargetInfo( - quantity=self.quantity, - unit=self.unit, - per_atom=self.per_atom, - gradients=self.gradients.copy(), - ) - - def update(self, other: "TargetInfo") -> None: - """Update this instance with the union of itself and ``other``. - - :raises ValueError: If ``quantity``, ``unit`` or ``per_atom`` do not match. - """ - if self.quantity != other.quantity: - raise ValueError( - f"Can't update TargetInfo with a different `quantity`: " - f"({self.quantity} != {other.quantity})" - ) + def _check_layout(self, layout: TensorMap) -> None: + """Check that the layout is a valid layout.""" + + # examine basic properties of all blocks + for block in layout.blocks(): + for sample_name in block.samples.names: + if sample_name not in ["system", "atom"]: + raise ValueError( + "The layout ``TensorMap`` of a target should only have samples " + "named 'system' or 'atom', but found " + f"'{sample_name}' instead." + ) + if len(block.values) != 0: + raise ValueError( + "The layout ``TensorMap`` of a target should have 0 " + f"samples, but found {len(block.values)} samples." + ) - if self.unit != other.unit: - raise ValueError( - f"Can't update TargetInfo with a different `unit`: " - f"({self.unit} != {other.unit})" - ) + # examine the components of the first block to decide whether this is + # a scalar, a Cartesian tensor or a spherical tensor - if self.per_atom != other.per_atom: + if len(layout) == 0: raise ValueError( - f"Can't update TargetInfo with a different `per_atom` property: " - f"({self.per_atom} != {other.per_atom})" + "The layout ``TensorMap`` of a target should have at least one " + "block, but found 0 blocks." ) - - self.gradients = self.gradients + other.gradients - - def union(self, other: "TargetInfo") -> "TargetInfo": - """Return the union of this instance with ``other``.""" - new = self.copy() - new.update(other) - return new - - -class TargetInfoDict(UserDict): - """ - A custom dictionary class for storing and managing ``TargetInfo`` instances. - - The subclass handles the update of :py:class:`TargetInfo` if a ``key`` is already - present. - """ - - # We use a `UserDict` with special methods because a normal dict does not support - # the update of nested instances. - def __setitem__(self, key, value): - if not isinstance(value, TargetInfo): - raise ValueError("value to set is not a `TargetInfo` instance") - if key in self: - self[key].update(value) - else: - super().__setitem__(key, value) - - def __and__(self, other: "TargetInfoDict") -> "TargetInfoDict": - return self.intersection(other) - - def __sub__(self, other: "TargetInfoDict") -> "TargetInfoDict": - return self.difference(other) - - def union(self, other: "TargetInfoDict") -> "TargetInfoDict": - """Union of this instance with ``other``.""" - new = self.copy() - new.update(other) - return new - - def intersection(self, other: "TargetInfoDict") -> "TargetInfoDict": - """Intersection of the the two instances as a new ``TargetInfoDict``. - - (i.e. all elements that are in both sets.) - - :raises ValueError: If intersected items with the same key are not the same. - """ - new_keys = self.keys() & other.keys() - - self_intersect = TargetInfoDict(**{key: self[key] for key in new_keys}) - other_intersect = TargetInfoDict(**{key: other[key] for key in new_keys}) - - if self_intersect == other_intersect: - return self_intersect + components_first_block = layout.block(0).components + if len(components_first_block) == 0: + self._is_scalar = True + elif components_first_block[0].names[0].startswith("xyz"): + self._is_cartesian = True + elif ( + len(components_first_block) == 1 + and components_first_block[0].names[0] == "o3_mu" + ): + self._is_spherical = True else: raise ValueError( - "Intersected items with the same key are not the same. Intersected " - f"keys are {','.join(new_keys)}" + "The layout ``TensorMap`` of a target should be " + "either scalars, Cartesian tensors or spherical tensors. The type of " + "the target could not be determined." ) - def difference(self, other: "TargetInfoDict") -> "TargetInfoDict": - """Difference of two instances as a new ``TargetInfoDict``. - - (i.e. all elements that are in this set but not in the other.) - """ + if self._is_scalar: + if layout.keys.names != ["_"]: + raise ValueError( + "The layout ``TensorMap`` of a scalar target should have " + "a single key sample named '_'." + ) + if len(layout.blocks()) != 1: + raise ValueError( + "The layout ``TensorMap`` of a scalar target should have " + "a single block." + ) + gradients_names = layout.block(0).gradients_list() + for gradient_name in gradients_names: + if gradient_name not in ["positions", "strain"]: + raise ValueError( + "Only `positions` and `strain` gradients are supported for " + "scalar targets. " + f"Found '{gradient_name}' instead." + ) + if self._is_cartesian: + if layout.keys.names != ["_"]: + raise ValueError( + "The layout ``TensorMap`` of a Cartesian tensor target should have " + "a single key sample named '_'." + ) + if len(layout.blocks()) != 1: + raise ValueError( + "The layout ``TensorMap`` of a Cartesian tensor target should have " + "a single block." + ) + if len(layout.block(0).gradients_list()) > 0: + raise ValueError( + "Gradients of Cartesian tensor targets are not supported." + ) - new_keys = self.keys() - other.keys() - return TargetInfoDict(**{key: self[key] for key in new_keys}) + if self._is_spherical: + if layout.keys.names != ["o3_lambda", "o3_sigma"]: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target " + "should have two keys named 'o3_lambda' and 'o3_sigma'." + f"Found '{layout.keys.names}' instead." + ) + for key, block in layout.items(): + o3_lambda, o3_sigma = int(key.values[0].item()), int( + key.values[1].item() + ) + if o3_sigma not in [-1, 1]: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a key sample 'o3_sigma' that is either -1 or 1." + f"Found '{o3_sigma}' instead." + ) + if o3_lambda < 0: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a key sample 'o3_lambda' that is non-negative." + f"Found '{o3_lambda}' instead." + ) + components = block.components + if len(components) != 1: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a single component." + ) + if len(components[0]) != 2 * o3_lambda + 1: + raise ValueError( + "Each ``TensorBlock`` of a spherical tensor target should have " + "a component with 2*o3_lambda + 1 elements." + f"Found '{len(components[0])}' elements instead." + ) + if len(block.gradients_list()) > 0: + raise ValueError( + "Gradients of spherical tensor targets are not supported." + ) class DatasetInfo: @@ -180,7 +223,7 @@ class DatasetInfo: """ def __init__( - self, length_unit: str, atomic_types: List[int], targets: TargetInfoDict + self, length_unit: str, atomic_types: List[int], targets: Dict[str, TargetInfo] ): self.length_unit = length_unit if length_unit is not None else "" self._atomic_types = set(atomic_types) @@ -233,6 +276,14 @@ def update(self, other: "DatasetInfo") -> None: ) self.atomic_types = self.atomic_types + other.atomic_types + + intersecting_target_keys = self.targets.keys() & other.targets.keys() + for key in intersecting_target_keys: + if self.targets[key] != other.targets[key]: + raise ValueError( + f"Can't update DatasetInfo with different target information for " + f"target '{key}': {self.targets[key]} != {other.targets[key]}" + ) self.targets.update(other.targets) def union(self, other: "DatasetInfo") -> "DatasetInfo": diff --git a/src/metatrain/utils/data/extract_targets.py b/src/metatrain/utils/data/extract_targets.py deleted file mode 100644 index ee86b29d4..000000000 --- a/src/metatrain/utils/data/extract_targets.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Dict, List, Union - -import torch - -from metatrain.utils.data import Dataset - -from .dataset import DatasetInfo, TargetInfo - - -def get_targets_dict( - datasets: List[Union[Dataset, torch.utils.data.Subset]], dataset_info: DatasetInfo -) -> Dict[str, TargetInfo]: - """ - This is a helper function that extracts all the possible targets and their - gradients from a list of datasets. - - :param datasets: A list of Datasets or Subsets. - :param dataset_info: A DatasetInfo object containing further - information about the dataset, namely the unit and quantity of the - targets. - - :returns: A dictionary mapping target names to ``TargetInfo`` objects. - - :raises ValueError: If the ``DatasetInfo`` object does not contain any of - the expected targets. - """ - - targets_dict = {} - for dataset in datasets: - targets = next(iter(dataset)) - targets = targets._asdict() - targets.pop("system") # system not needed - - # targets is now a dictionary of TensorMaps - for target_name, target_tmap in targets.items(): - if target_name not in dataset_info.targets.keys(): - raise ValueError( - f"Target {target_name} not found in the targets " - "specified in dataset_info." - ) - if target_name not in targets_dict: - targets_dict[target_name] = TargetInfo( - quantity=dataset_info.targets[target_name].quantity, - unit=dataset_info.targets[target_name].unit, - gradients=target_tmap.block(0).gradients_list(), - ) - - return targets_dict diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 2f95263c5..502aea40f 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -1,12 +1,12 @@ -from typing import Tuple +from typing import Dict, Tuple from omegaconf import DictConfig -from .dataset import Dataset, TargetInfoDict +from .dataset import Dataset, TargetInfo from .readers import read_systems, read_targets -def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]: +def get_dataset(options: DictConfig) -> Tuple[Dataset, Dict[str, TargetInfo]]: """ Gets a dataset given a configuration dictionary. @@ -18,7 +18,7 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]: systems and targets in the dataset. :returns: A tuple containing a ``Dataset`` object and a - ``TargetInfoDict`` containing additional information (units, + ``Dict[str, TargetInfo]`` containing additional information (units, physical quantities, ...) on the targets in the dataset """ diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index b095195b1..7a6076143 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -8,7 +8,7 @@ from metatensor.torch.atomistic import System from omegaconf import DictConfig -from ..dataset import TargetInfo, TargetInfoDict +from ..dataset import TargetInfo logger = logging.getLogger(__name__) @@ -161,7 +161,7 @@ def read_virial( def read_targets( conf: DictConfig, -) -> Tuple[Dict[str, List[TensorMap]], TargetInfoDict]: +) -> Tuple[Dict[str, List[TensorMap]], Dict[str, TargetInfo]]: """Reading all target information from a fully expanded config. To get such a config you can use :func:`expand_dataset_config @@ -175,9 +175,8 @@ def read_targets( :param conf: config containing the keys for what should be read. :returns: Dictionary containing a list of TensorMaps for each target section in the - config as well as a :py:class:`TargetInfoDict - ` instance containing the metadata of the - targets. + config as well as a ``Dict[str, TargetInfo]`` object + containing the metadata of the targets. :raises ValueError: if the target name is not valid. Valid target names are those that either start with ``mtt::`` or those that are in the list of @@ -185,7 +184,7 @@ def read_targets( https://docs.metatensor.org/latest/atomistic/outputs.html) """ target_dictionary = {} - target_info_dictionary = TargetInfoDict() + target_info_dictionary = {} standard_outputs_list = ["energy"] for target_key, target in conf.items(): @@ -288,8 +287,39 @@ def read_targets( target_info_dictionary[target_key] = TargetInfo( quantity=target["quantity"], unit=target["unit"], - per_atom=False, # TODO: read this from the config - gradients=target_info_gradients, + layout=_empty_tensor_map_like(target_dictionary[target_key][0]), ) return target_dictionary, target_info_dictionary + + +def _empty_tensor_map_like(tensor_map: TensorMap) -> TensorMap: + new_keys = tensor_map.keys + new_blocks: List[TensorBlock] = [] + for block in tensor_map.blocks(): + new_block = _empty_tensor_block_like(block) + new_blocks.append(new_block) + return TensorMap(keys=new_keys, blocks=new_blocks) + + +def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock: + new_block = TensorBlock( + values=torch.empty( + (0,) + tensor_block.values.shape[1:], + dtype=torch.float64, # metatensor can't serialize otherwise + device=tensor_block.values.device, + ), + samples=Labels( + names=tensor_block.samples.names, + values=torch.empty( + (0, tensor_block.samples.values.shape[1]), + dtype=tensor_block.samples.values.dtype, + device=tensor_block.samples.values.device, + ), + ), + components=tensor_block.components, + properties=tensor_block.properties, + ) + for gradient_name, gradient in tensor_block.gradients(): + new_block.add_gradient(gradient_name, _empty_tensor_block_like(gradient)) + return new_block diff --git a/src/metatrain/utils/evaluate_model.py b/src/metatrain/utils/evaluate_model.py index 5b4500a21..8a42ec3d0 100644 --- a/src/metatrain/utils/evaluate_model.py +++ b/src/metatrain/utils/evaluate_model.py @@ -12,7 +12,7 @@ register_autograd_neighbors, ) -from .data import TargetInfoDict +from .data import TargetInfo from .export import is_exported from .output_gradient import compute_gradient @@ -33,7 +33,7 @@ def evaluate_model( torch.jit._script.RecursiveScriptModule, ], systems: List[System], - targets: TargetInfoDict, + targets: Dict[str, TargetInfo], is_training: bool, check_consistency: bool = False, ) -> Dict[str, TensorMap]: @@ -234,7 +234,7 @@ def _get_model_outputs( torch.jit._script.RecursiveScriptModule, ], systems: List[System], - targets: TargetInfoDict, + targets: Dict[str, TargetInfo], check_consistency: bool, ) -> Dict[str, TensorMap]: if is_exported(model): diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index 164dd4267..6cfa42394 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -12,8 +12,7 @@ ) from torch.utils.data import DataLoader -from .data import DatasetInfo, TargetInfoDict, get_atomic_types -from .data.extract_targets import get_targets_dict +from .data import DatasetInfo, TargetInfo, get_atomic_types from .evaluate_model import evaluate_model from .per_atom import average_by_num_atoms @@ -261,7 +260,7 @@ class in ``metatrain``. def compute_covariance_as_pseudo_hessian( self, train_loader: DataLoader, - target_infos: TargetInfoDict, + target_infos: Dict[str, TargetInfo], loss_fn: Callable, parameters: List[torch.nn.Parameter], ) -> None: @@ -305,7 +304,7 @@ class in ``metatrain``. atomic_types=get_atomic_types(dataset), targets=target_infos, ) - train_targets = get_targets_dict([train_loader.dataset], dataset_info) + train_targets = dataset_info.targets device = self.covariance.device dtype = self.covariance.dtype for batch in train_loader: @@ -318,7 +317,7 @@ class in ``metatrain``. predictions = evaluate_model( self.model, systems, - TargetInfoDict(**{key: train_targets[key] for key in targets.keys()}), + {key: train_targets[key] for key in targets.keys()}, is_training=True, # keep the computational graph ) diff --git a/src/metatrain/utils/testing.py b/src/metatrain/utils/testing.py new file mode 100644 index 000000000..faedbdb00 --- /dev/null +++ b/src/metatrain/utils/testing.py @@ -0,0 +1,76 @@ +# This file contains some example TensorMap layouts that can be +# used for testing purposes. + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), +) +energy_layout = TensorMap( + keys=Labels.single(), + blocks=[block], +) + +block_with_position_gradients = block.copy() +position_gradient_block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 1, dtype=torch.float64), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), +) +block_with_position_gradients.add_gradient("positions", position_gradient_block) +energy_force_layout = TensorMap( + keys=Labels.single(), + blocks=[block_with_position_gradients], +) + +block_with_position_and_strain_gradients = block_with_position_gradients.copy() +strain_gradient_block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 3, 1, dtype=torch.float64), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), +) +block_with_position_and_strain_gradients.add_gradient("strain", strain_gradient_block) +energy_force_stress_layout = TensorMap( + keys=Labels.single(), + blocks=[block_with_position_and_strain_gradients], +) + +block_with_strain_gradients = block.copy() +block_with_strain_gradients.add_gradient("strain", strain_gradient_block) +energy_stress_layout = TensorMap( + keys=Labels.single(), + blocks=[block_with_strain_gradients], +) diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py index 6510deee9..6da30fb5c 100644 --- a/tests/cli/test_eval_model.py +++ b/tests/cli/test_eval_model.py @@ -11,6 +11,7 @@ from metatrain.cli.eval import eval_model from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import EVAL_OPTIONS_PATH, MODEL_HYPERS, MODEL_PATH, RESOURCES_PATH @@ -84,9 +85,7 @@ def test_eval_export(monkeypatch, tmp_path, options): length_unit="angstrom", atomic_types={1, 6, 7, 8}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=[] - ) + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/cli/test_export_model.py b/tests/cli/test_export_model.py index a3e62e9d7..ca93a7b00 100644 --- a/tests/cli/test_export_model.py +++ b/tests/cli/test_export_model.py @@ -15,6 +15,7 @@ from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.architectures import find_all_architectures from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS, RESOURCES_PATH @@ -28,9 +29,7 @@ def test_export(monkeypatch, tmp_path, path): length_unit="angstrom", atomic_types={1}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=[] - ) + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -93,9 +92,7 @@ def test_reexport(monkeypatch, tmp_path): length_unit="angstrom", atomic_types={1, 6, 7, 8}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=[] - ) + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 6fb54efd2..e6e398158 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -487,6 +487,7 @@ def test_continue_different_dataset(options, monkeypatch, tmp_path): options["training_set"]["systems"]["read_from"] = "ethanol_reduced_100.xyz" options["training_set"]["targets"]["energy"]["key"] = "energy" + options["training_set"]["targets"]["energy"]["forces"] = False train_model(options, continue_from=MODEL_PATH_64_BIT) diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index 20100aa3d..9c76c5a7b 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -2,13 +2,13 @@ import pytest import torch +from metatensor.torch import Labels, TensorBlock, TensorMap from omegaconf import OmegaConf from metatrain.utils.data import ( Dataset, DatasetInfo, TargetInfo, - TargetInfoDict, check_datasets, collate_fn, get_all_targets, @@ -22,81 +22,159 @@ RESOURCES_PATH = Path(__file__).parents[2] / "resources" -def test_target_info_default(): - target_info = TargetInfo(quantity="energy", unit="kcal/mol") +@pytest.fixture +def layout_scalar(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.empty(0, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), + ) + ], + ) + + +@pytest.fixture +def layout_spherical(): + return TensorMap( + keys=Labels( + names=["o3_lambda", "o3_sigma"], + values=torch.tensor([[0, 1], [2, 1]]), + ), + blocks=[ + TensorBlock( + values=torch.empty(0, 1, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(0, 1, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.single(), + ), + TensorBlock( + values=torch.empty(0, 5, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(-2, 3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.single(), + ), + ], + ) + + +@pytest.fixture +def layout_cartesian(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.empty(0, 3, 3, 1), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.single(), + ), + ], + ) + + +def test_target_info_scalar(layout_scalar): + target_info = TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) assert target_info.quantity == "energy" assert target_info.unit == "kcal/mol" - assert target_info.per_atom is False assert target_info.gradients == [] + assert not target_info.per_atom - expected = ( - "TargetInfo(quantity='energy', unit='kcal/mol', per_atom=False, gradients=[])" - ) - assert target_info.__repr__() == expected + expected_start = "TargetInfo(quantity='energy', unit='kcal/mol'" + assert target_info.__repr__()[: len(expected_start)] == expected_start -def test_target_info_gradients(): +def test_target_info_spherical(layout_spherical): target_info = TargetInfo( - quantity="energy", - unit="kcal/mol", - per_atom=True, - gradients=["positions", "positions"], + quantity="mtt::spherical", unit="kcal/mol", layout=layout_spherical ) - assert target_info.quantity == "energy" + assert target_info.quantity == "mtt::spherical" assert target_info.unit == "kcal/mol" - assert target_info.per_atom is True - assert target_info.gradients == ["positions"] + assert target_info.gradients == [] + assert not target_info.per_atom - expected = ( - "TargetInfo(quantity='energy', unit='kcal/mol', per_atom=True, " - "gradients=['positions'])" - ) - assert target_info.__repr__() == expected + expected_start = "TargetInfo(quantity='mtt::spherical', unit='kcal/mol'" + assert target_info.__repr__()[: len(expected_start)] == expected_start -def test_list_gradients(): - info1 = TargetInfo(quantity="energy", unit="eV") +def test_target_info_cartesian(layout_cartesian): + target_info = TargetInfo( + quantity="mtt::cartesian", unit="kcal/mol", layout=layout_cartesian + ) - info1.gradients = ["positions"] - assert info1.gradients == ["positions"] + assert target_info.quantity == "mtt::cartesian" + assert target_info.unit == "kcal/mol" + assert target_info.gradients == [] + assert not target_info.per_atom - info1.gradients += ["strain"] - assert info1.gradients == ["positions", "strain"] + expected_start = "TargetInfo(quantity='mtt::cartesian', unit='kcal/mol'" + assert target_info.__repr__()[: len(expected_start)] == expected_start -def test_unit_none_conversion(): - info = TargetInfo(quantity="energy", unit=None) +def test_unit_none_conversion(layout_scalar): + info = TargetInfo(quantity="energy", unit=None, layout=layout_scalar) assert info.unit == "" -def test_length_unit_none_conversion(): +def test_length_unit_none_conversion(layout_scalar): dataset_info = DatasetInfo( length_unit=None, atomic_types=[1, 2, 3], - targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")), + targets={ + "energy": TargetInfo( + quantity="energy", unit="kcal/mol", layout=layout_scalar + ) + }, ) assert dataset_info.length_unit == "" -def test_target_info_copy(): - info = TargetInfo(quantity="energy", unit="eV", gradients=["positions"]) - copy = info.copy() - assert copy == info - assert copy is not info - - -def test_target_info_eq(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) +def test_target_info_eq(layout_scalar): + info1 = TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) + info2 = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) assert info1 == info1 assert info1 != info2 -def test_target_info_eq_error(): - info = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) +def test_target_info_eq_error(layout_scalar): + info = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) match = ( "Comparison between a TargetInfo instance and a list instance is not " @@ -106,152 +184,15 @@ def test_target_info_eq_error(): _ = info == [1, 2, 3] -def test_target_info_update(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["strain", "aaa"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["positions"]) - info1.update(info2) - assert info1.gradients == ["aaa", "positions", "strain"] - - -def test_target_info_union(): - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - info_new = info1.union(info2) - assert isinstance(info_new, TargetInfo) - assert info_new.gradients == ["position", "strain"] - - -def test_target_info_update_non_matching_quantity(): - info1 = TargetInfo(quantity="energy", unit="eV") - info2 = TargetInfo(quantity="force", unit="eV") - match = r"Can't update TargetInfo with a different `quantity`: \(energy != force\)" - with pytest.raises(ValueError, match=match): - info1.update(info2) - - -def test_target_info_update_non_matching_unit(): - info1 = TargetInfo(quantity="energy", unit="eV") - info2 = TargetInfo(quantity="energy", unit="kcal") - match = r"Can't update TargetInfo with a different `unit`: \(eV != kcal\)" - with pytest.raises(ValueError, match=match): - info1.update(info2) - - -def test_target_info_update_non_matching_per_atom(): - info1 = TargetInfo(quantity="energy", unit="eV", per_atom=True) - info2 = TargetInfo(quantity="energy", unit="eV", per_atom=False) - match = "Can't update TargetInfo with a different `per_atom` property: " - with pytest.raises(ValueError, match=match): - info1.update(info2) - - -def test_target_info_dict_setitem_new_entry(): - tid = TargetInfoDict() - info = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid["energy"] = info - assert tid["energy"] == info - - -def test_target_info_dict_setitem_update_entry(): - tid = TargetInfoDict() - info1 = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - info2 = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - tid["energy"] = info1 - tid["energy"] = info2 - assert tid["energy"].gradients == ["position", "strain"] - - -def test_target_info_dict_setitem_value_error(): - tid = TargetInfoDict() - with pytest.raises(ValueError, match="value to set is not a `TargetInfo` instance"): - tid["energy"] = "not a TargetInfo" - - -def test_target_info_dict_union(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - merged = tid1.union(tid2) - assert merged["energy"] == tid1["energy"] - assert merged["myenergy"] == tid2["myenergy"] - - -def test_target_info_dict_merge_error(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - - tid2 = TargetInfoDict() - tid2["energy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients=["strain"] - ) - - match = r"Can't update TargetInfo with a different `unit`: \(eV != kcal/mol\)" - with pytest.raises(ValueError, match=match): - tid1.union(tid2) - - -def test_target_info_dict_intersection(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - intersection = tid1.intersection(tid2) - assert len(intersection) == 1 - assert intersection["myenergy"] == tid1["myenergy"] - - # Test `&` operator - intersection_and = tid1 & tid2 - assert intersection_and == intersection - - -def test_target_info_dict_intersection_error(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients=["strain"] - ) - - match = ( - r"Intersected items with the same key are not the same. Intersected " - r"keys are myenergy" - ) - - with pytest.raises(ValueError, match=match): - tid1.intersection(tid2) - - -def test_target_info_dict_difference(): - tid1 = TargetInfoDict() - tid1["energy"] = TargetInfo(quantity="energy", unit="eV", gradients=["position"]) - tid1["myenergy"] = TargetInfo(quantity="energy", unit="eV", gradients=["strain"]) - - tid2 = TargetInfoDict() - tid2["myenergy"] = TargetInfo( - quantity="energy", unit="kcal/mol", gradients=["strain"] - ) - - difference = tid1.difference(tid2) - assert len(difference) == 1 - assert difference["energy"] == tid1["energy"] - - difference_sub = tid1 - tid2 - assert difference_sub == difference - - -def test_dataset_info(): +def test_dataset_info(layout_scalar): """Tests the DatasetInfo class.""" - targets = TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")) - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="kcal/mol") + targets = { + "energy": TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) + } + targets["mtt::U0"] = TargetInfo( + quantity="energy", unit="kcal/mol", layout=layout_scalar + ) dataset_info = DatasetInfo( length_unit="angstrom", atomic_types=[3, 1, 2], targets=targets @@ -271,9 +212,13 @@ def test_dataset_info(): assert dataset_info.__repr__() == expected -def test_set_atomic_types(): - targets = TargetInfoDict(energy=TargetInfo(quantity="energy", unit="kcal/mol")) - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="kcal/mol") +def test_set_atomic_types(layout_scalar): + targets = { + "energy": TargetInfo(quantity="energy", unit="kcal/mol", layout=layout_scalar) + } + targets["mtt::U0"] = TargetInfo( + quantity="energy", unit="kcal/mol", layout=layout_scalar + ) dataset_info = DatasetInfo( length_unit="angstrom", atomic_types=[3, 1, 2], targets=targets @@ -286,10 +231,12 @@ def test_set_atomic_types(): assert dataset_info.atomic_types == [1, 4, 5, 7] -def test_dataset_info_copy(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") - targets["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") +def test_dataset_info_copy(layout_scalar, layout_cartesian): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) + targets["mtt::my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_cartesian + ) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) copy = info.copy() @@ -298,31 +245,35 @@ def test_dataset_info_copy(): assert copy is not info -def test_dataset_info_update(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_update(layout_scalar, layout_spherical): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() - targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets2["mtt::my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_spherical + ) info2 = DatasetInfo(length_unit="angstrom", atomic_types=[8], targets=targets2) info.update(info2) assert info.atomic_types == [1, 6, 8] assert info.targets["energy"] == targets["energy"] - assert info.targets["forces"] == targets2["forces"] + assert info.targets["mtt::my-target"] == targets2["mtt::my-target"] -def test_dataset_info_update_non_matching_length_unit(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_update_non_matching_length_unit(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() - targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets2["mtt::my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_scalar + ) info2 = DatasetInfo(length_unit="nanometer", atomic_types=[8], targets=targets2) @@ -335,23 +286,25 @@ def test_dataset_info_update_non_matching_length_unit(): info.update(info2) -def test_dataset_info_eq(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_eq(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) targets2 = targets.copy() - targets2["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets2["my-target"] = TargetInfo( + quantity="mtt::my-target", unit="eV/Angstrom", layout=layout_scalar + ) info2 = DatasetInfo(length_unit="nanometer", atomic_types=[8], targets=targets2) assert info == info assert info != info2 -def test_dataset_info_eq_error(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_eq_error(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) @@ -363,31 +316,39 @@ def test_dataset_info_eq_error(): _ = info == [1, 2, 3] -def test_dataset_info_update_different_target_info(): - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") +def test_dataset_info_update_different_target_info(layout_scalar): + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) - targets2 = TargetInfoDict() - targets2["energy"] = TargetInfo(quantity="energy", unit="eV/Angstrom") + targets2 = {} + targets2["energy"] = TargetInfo( + quantity="energy", unit="eV/Angstrom", layout=layout_scalar + ) info2 = DatasetInfo(length_unit="angstrom", atomic_types=[8], targets=targets2) - match = r"Can't update TargetInfo with a different `unit`: \(eV != eV/Angstrom\)" + match = ( + "Can't update DatasetInfo with different target information for target 'energy'" + ) with pytest.raises(ValueError, match=match): info.update(info2) -def test_dataset_info_union(): +def test_dataset_info_union(layout_scalar, layout_cartesian): """Tests the union method.""" - targets = TargetInfoDict() - targets["energy"] = TargetInfo(quantity="energy", unit="eV") - targets["forces"] = TargetInfo(quantity="mtt::forces", unit="eV/Angstrom") + targets = {} + targets["energy"] = TargetInfo(quantity="energy", unit="eV", layout=layout_scalar) + targets["forces"] = TargetInfo( + quantity="mtt::forces", unit="eV/Angstrom", layout=layout_scalar + ) info = DatasetInfo(length_unit="angstrom", atomic_types=[1, 6], targets=targets) other_targets = targets.copy() - other_targets["mtt::stress"] = TargetInfo(quantity="mtt::stress", unit="GPa") + other_targets["mtt::stress"] = TargetInfo( + quantity="mtt::stress", unit="GPa", layout=layout_cartesian + ) other_info = DatasetInfo( length_unit="angstrom", atomic_types=[1], targets=other_targets @@ -606,7 +567,7 @@ def test_collate_fn(): assert isinstance(batch[1], dict) -def test_get_stats(): +def test_get_stats(layout_scalar): """Tests the get_stats method of Dataset and Subset.""" systems = read_systems(RESOURCES_PATH / "qm9_reduced_100.xyz") @@ -644,8 +605,8 @@ def test_get_stats(): length_unit="angstrom", atomic_types=[1, 6], targets={ - "mtt::U0": TargetInfo(quantity="energy", unit="eV"), - "energy": TargetInfo(quantity="energy", unit="eV"), + "mtt::U0": TargetInfo(quantity="energy", unit="eV", layout=layout_scalar), + "energy": TargetInfo(quantity="energy", unit="eV", layout=layout_scalar), }, ) diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 92b5e656b..5876d8e70 100644 --- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -11,7 +11,7 @@ from omegaconf import OmegaConf from test_targets_ase import ase_system, ase_systems -from metatrain.utils.data.dataset import TargetInfo, TargetInfoDict +from metatrain.utils.data.dataset import TargetInfo from metatrain.utils.data.readers import ( read_energy, read_forces, @@ -178,7 +178,7 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): assert any(["Forces found" in rec.message for rec in caplog.records]) assert type(result) is dict - assert type(target_info_dict) is TargetInfoDict + assert type(target_info_dict) is dict if stress_dict: assert any(["Stress found" in rec.message for rec in caplog.records]) diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index 3daf3c980..b3472d2f3 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -8,12 +8,13 @@ from omegaconf import OmegaConf from metatrain.utils.additive import ZBL, CompositionModel, remove_additive -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict +from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo from metatrain.utils.data.readers import read_systems, read_targets from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_layout RESOURCES_PATH = Path(__file__).parents[1] / "resources" @@ -82,14 +83,12 @@ def test_composition_model_train(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) @@ -211,14 +210,12 @@ def test_composition_model_torchscript(tmpdir): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) composition_model = torch.jit.script(composition_model) @@ -342,14 +339,12 @@ def test_composition_model_missing_types(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) with pytest.raises( @@ -363,14 +358,12 @@ def test_composition_model_missing_types(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8, 100], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="energy", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="energy", + layout=energy_layout, + ) + }, ), ) with pytest.warns( @@ -394,14 +387,12 @@ def test_composition_model_wrong_target(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1], - targets=TargetInfoDict( - { - "energy": TargetInfo( - quantity="FOO", - per_atom=False, - ) - } - ), + targets={ + "energy": TargetInfo( + quantity="FOO", + layout=energy_layout, + ) + }, ), ) diff --git a/tests/utils/test_evaluate_model.py b/tests/utils/test_evaluate_model.py index e2bd81eca..9c147378d 100644 --- a/tests/utils/test_evaluate_model.py +++ b/tests/utils/test_evaluate_model.py @@ -10,6 +10,7 @@ get_requested_neighbor_lists, get_system_with_neighbor_lists, ) +from metatrain.utils.testing import energy_force_stress_layout from . import MODEL_HYPERS, RESOURCES_PATH @@ -27,7 +28,7 @@ def test_evaluate_model(training, exported): targets = { "energy": TargetInfo( - quantity="energy", unit="eV", gradients=["positions", "strain"] + quantity="energy", unit="eV", layout=energy_force_stress_layout ) } diff --git a/tests/utils/test_export.py b/tests/utils/test_export.py index 439a279f4..a0dcc1784 100644 --- a/tests/utils/test_export.py +++ b/tests/utils/test_export.py @@ -7,6 +7,7 @@ from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, TargetInfo from metatrain.utils.export import export, is_exported +from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS, RESOURCES_PATH @@ -18,7 +19,9 @@ def test_export(tmp_path): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"energy": TargetInfo(quantity="energy", unit="eV")}, + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -48,7 +51,9 @@ def test_reexport(monkeypatch, tmp_path): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"energy": TargetInfo(quantity="energy", unit="eV")}, + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -79,7 +84,9 @@ def test_length_units_warning(): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"energy": TargetInfo(quantity="energy", unit="eV")}, + targets={ + "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -100,7 +107,7 @@ def test_units_warning(): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={"mtt::output": TargetInfo(quantity="energy")}, + targets={"mtt::output": TargetInfo(quantity="energy", layout=energy_layout)}, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/utils/test_external_naming.py b/tests/utils/test_external_naming.py index 666d1c890..0e3db1371 100644 --- a/tests/utils/test_external_naming.py +++ b/tests/utils/test_external_naming.py @@ -1,14 +1,15 @@ from metatrain.utils.data.dataset import TargetInfo from metatrain.utils.external_naming import to_external_name, to_internal_name +from metatrain.utils.testing import energy_layout def test_to_external_name(): """Tests the to_external_name function.""" quantities = { - "energy": TargetInfo(quantity="energy"), - "mtt::free_energy": TargetInfo(quantity="energy"), - "mtt::foo": TargetInfo(quantity="bar"), + "energy": TargetInfo(quantity="energy", layout=energy_layout), + "mtt::free_energy": TargetInfo(quantity="energy", layout=energy_layout), + "mtt::foo": TargetInfo(quantity="bar", layout=energy_layout), } assert to_external_name("energy_positions_gradients", quantities) == "forces" diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py index 513d43930..75aaf76e7 100644 --- a/tests/utils/test_output_gradient.py +++ b/tests/utils/test_output_gradient.py @@ -6,6 +6,11 @@ from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.output_gradient import compute_gradient +from metatrain.utils.testing import ( + energy_force_layout, + energy_force_stress_layout, + energy_stress_layout, +) from . import MODEL_HYPERS, RESOURCES_PATH @@ -19,7 +24,7 @@ def test_forces(is_training): atomic_types={1, 6, 7, 8}, targets={ "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=["positions"] + quantity="energy", unit="eV", layout=energy_force_layout ) }, ) @@ -78,7 +83,7 @@ def test_virial(is_training): atomic_types={6}, targets={ "energy": TargetInfo( - quantity="energy", unit="eV", per_atom=False, gradients=["strain"] + quantity="energy", unit="eV", layout=energy_stress_layout ) }, ) @@ -151,8 +156,7 @@ def test_both(is_training): "energy": TargetInfo( quantity="energy", unit="eV", - per_atom=False, - gradients=["positions", "strain"], + layout=energy_force_stress_layout, ) }, )