From f5b3527f280056e22f276a4d01d4f1aa34c7c296 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 29 Aug 2024 08:47:48 +0200 Subject: [PATCH 1/3] Update metatensor --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c1d7c74..7be2777b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,9 @@ authors = [{name = "metatrain developers"}] dependencies = [ "ase < 3.23.0", - "metatensor-learn==0.2.2", - "metatensor-operations==0.2.1", - "metatensor-torch==0.5.3", + "metatensor-learn==0.2.3", + "metatensor-operations==0.2.3", + "metatensor-torch==0.5.4", "jsonschema", "omegaconf", "python-hostlist", From e297fad279423b63a82f9efabca247253a301252 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 29 Aug 2024 12:52:13 +0200 Subject: [PATCH 2/3] `._module` -> `.module` --- src/metatrain/experimental/pet/tests/test_pet_compatibility.py | 2 +- src/metatrain/utils/llpr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py index 16f7955a..852750be 100644 --- a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py @@ -159,7 +159,7 @@ def test_predictions_compatibility(cutoff): "neighbors_pos": batch.neighbors_pos, } - pet = model._module.pet + pet = model.module.pet pet_prediction = pet.forward(batch_dict) diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index 20eb68df..7f23fda9 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -31,7 +31,7 @@ def __init__( super().__init__() self.model = model - self.ll_feat_size = self.model._module.last_layer_feature_size + self.ll_feat_size = self.model.module.last_layer_feature_size # update capabilities: now we have additional outputs for the uncertainty old_capabilities = self.model.capabilities() From b6caa7b57e93e67aba54fdfbe85b94d838a5e347 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 29 Aug 2024 13:28:09 +0200 Subject: [PATCH 3/3] Update dataset --- examples/programmatic/llpr/llpr.py | 2 +- src/metatrain/cli/eval.py | 2 +- src/metatrain/cli/train.py | 16 ++- .../alchemical_model/tests/test_regression.py | 2 +- .../alchemical_model/utils/normalize.py | 10 +- .../experimental/gap/tests/test_errors.py | 4 +- .../experimental/gap/tests/test_regression.py | 4 +- .../gap/tests/test_torchscript.py | 2 +- .../soap_bpnn/tests/test_continue.py | 2 +- .../soap_bpnn/tests/test_regression.py | 2 +- src/metatrain/utils/data/__init__.py | 2 +- src/metatrain/utils/data/dataset.py | 102 ++---------------- src/metatrain/utils/data/extract_targets.py | 1 + src/metatrain/utils/data/get_dataset.py | 2 +- tests/utils/data/test_combine_dataloaders.py | 8 +- tests/utils/data/test_dataset.py | 37 ++++--- tests/utils/data/test_get_dataset.py | 4 +- tests/utils/test_composition.py | 2 +- tests/utils/test_llpr.py | 2 +- 19 files changed, 69 insertions(+), 137 deletions(-) diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index df79773d..8db135c8 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -72,7 +72,7 @@ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems ] -dataset = Dataset({"system": qm9_systems, **targets}) +dataset = Dataset.from_dict({"system": qm9_systems, **targets}) # We also load a single ethanol molecule on which we will compute properties. # This system is loaded without targets, as we are only interested in the LPR diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 4df6572a..93adb162 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -296,7 +296,7 @@ def eval_model( gradients=gradients, ) - eval_dataset = Dataset({"system": eval_systems, **eval_targets}) + eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets}) # Evaluate the model try: diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index e567b5c8..ffc880f7 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -15,7 +15,13 @@ from .. import PACKAGE_ROOT from ..utils.architectures import check_architecture_options, get_default_hypers -from ..utils.data import DatasetInfo, TargetInfoDict, get_atomic_types, get_dataset +from ..utils.data import ( + DatasetInfo, + TargetInfoDict, + get_atomic_types, + get_dataset, + get_stats, +) from ..utils.data.dataset import _train_test_random_split from ..utils.devices import pick_devices from ..utils.distributed.logging import is_main_process @@ -290,7 +296,7 @@ def train_model( else: index = f" {i}" logger.info( - f"Training dataset{index}:\n {train_dataset.get_stats(dataset_info)}" + f"Training dataset{index}:\n {get_stats(train_dataset, dataset_info)}" ) for i, val_dataset in enumerate(val_datasets): @@ -299,7 +305,7 @@ def train_model( else: index = f" {i}" logger.info( - f"Validation dataset{index}:\n {val_dataset.get_stats(dataset_info)}" + f"Validation dataset{index}:\n {get_stats(val_dataset, dataset_info)}" ) for i, test_dataset in enumerate(test_datasets): @@ -307,7 +313,9 @@ def train_model( index = "" else: index = f" {i}" - logger.info(f"Test dataset{index}:\n {test_dataset.get_stats(dataset_info)}") + logger.info( + f"Test dataset{index}:\n {get_stats(test_dataset, dataset_info)}" + ) ########################### # SAVE EXPANDED OPTIONS ### diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index fe8dec96..4dbc6ed0 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -92,7 +92,7 @@ def test_regression_train(): } } targets, target_info_dict = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() diff --git a/src/metatrain/experimental/alchemical_model/utils/normalize.py b/src/metatrain/experimental/alchemical_model/utils/normalize.py index addc8549..494aa038 100644 --- a/src/metatrain/experimental/alchemical_model/utils/normalize.py +++ b/src/metatrain/experimental/alchemical_model/utils/normalize.py @@ -18,10 +18,10 @@ def get_average_number_of_atoms( """ average_number_of_atoms = [] for dataset in datasets: - dtype = dataset[0]["system"].positions.dtype + dtype = dataset[0].system.positions.dtype num_atoms = [] for i in range(len(dataset)): - system = dataset[i]["system"] + system = dataset[i].system num_atoms.append(len(system)) average_number_of_atoms.append(torch.mean(torch.tensor(num_atoms, dtype=dtype))) return torch.tensor(average_number_of_atoms) @@ -39,9 +39,9 @@ def get_average_number_of_neighbors( average_number_of_neighbors = [] for dataset in datasets: num_neighbor = [] - dtype = dataset[0]["system"].positions.dtype + dtype = dataset[0].system.positions.dtype for i in range(len(dataset)): - system = dataset[i]["system"] + system = dataset[i].system known_neighbor_lists = system.known_neighbor_lists() if len(known_neighbor_lists) == 0: raise ValueError(f"system {system} does not have a neighbor list") @@ -94,4 +94,4 @@ def remove_composition_from_dataset( new_systems.append(system) new_properties.append(property) - return Dataset({"system": new_systems, property_name: new_properties}) + return Dataset.from_dict({"system": new_systems, property_name: new_properties}) diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 7cc1958f..24e5c03b 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -54,7 +54,9 @@ def test_ethanol_regression_train_and_invariance(): } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems[:2], "energy": targets["energy"][:2]}) + dataset = Dataset.from_dict( + {"system": systems[:2], "energy": targets["energy"][:2]} + ) hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 30 diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index e4d2dda1..e2a2ee72 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -55,7 +55,7 @@ def test_regression_train_and_invariance(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) target_info_dict = TargetInfoDict() target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV") @@ -132,7 +132,7 @@ def test_ethanol_regression_train_and_invariance(): } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "energy": targets["energy"]}) + dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) hypers = copy.deepcopy(DEFAULT_HYPERS) hypers["model"]["krr"]["num_sparse_points"] = 900 diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index f0680fd4..967a8335 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ -36,7 +36,7 @@ def test_torchscript(): # for system in systems: # system.types = torch.ones(len(system.types), dtype=torch.int32) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() gap = GAP(DEFAULT_HYPERS["model"], dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index 9bd9b0e6..dac96cd9 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -44,7 +44,7 @@ def test_continue(monkeypatch, tmp_path): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 0 diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 07e3871a..0663da48 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -70,7 +70,7 @@ def test_regression_train(): } } targets, target_info_dict = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() hypers["training"]["num_epochs"] = 2 diff --git a/src/metatrain/utils/data/__init__.py b/src/metatrain/utils/data/__init__.py index b0047907..0aa12843 100644 --- a/src/metatrain/utils/data/__init__.py +++ b/src/metatrain/utils/data/__init__.py @@ -7,7 +7,7 @@ get_all_targets, collate_fn, check_datasets, - group_and_join, + get_stats, ) from .readers import ( # noqa: F401 read_energy, diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index a4225c02..6debbe4e 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -3,10 +3,10 @@ from collections import UserDict from typing import Any, Dict, List, Optional, Tuple, Union -import metatensor.learn import numpy as np -import torch +from metatensor.learn.data import Dataset, group_and_join from metatensor.torch import TensorMap +from torch.utils.data import Subset from ..external_naming import to_external_name from ..units import get_gradient_units @@ -242,60 +242,7 @@ def union(self, other: "DatasetInfo") -> "DatasetInfo": return new -class Dataset: - """A version of the `metatensor.learn.Dataset` class that allows for - the use of `mtt::` prefixes in the keys of the dictionary. See - https://github.com/lab-cosmo/metatensor/issues/621. - - It is important to note that, instead of named tuples, this class - accepts and returns dictionaries. - - :param dict: A dictionary with the data to be stored in the dataset. - """ - - def __init__(self, dict: Dict): - - new_dict = {} - for key, value in dict.items(): - key = key.replace("mtt::", "mtt_") - new_dict[key] = value - - self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict) - - def __getitem__(self, idx: int) -> Dict: - - mts_dataset_item = self.mts_learn_dataset[idx]._asdict() - new_dict = {} - for key, value in mts_dataset_item.items(): - key = key.replace("mtt_", "mtt::") - new_dict[key] = value - - return new_dict - - def __len__(self) -> int: - return len(self.mts_learn_dataset) - - def __iter__(self): - for i in range(len(self)): - yield self[i] - - def get_stats(self, dataset_info: DatasetInfo) -> str: - return _get_dataset_stats(self, dataset_info) - - -class Subset(torch.utils.data.Subset): - """ - A version of `torch.utils.data.Subset` containing a `get_stats` method - allowing us to print information about atomistic datasets. - """ - - def get_stats(self, dataset_info: DatasetInfo) -> str: - return _get_dataset_stats(self, dataset_info) - - -def _get_dataset_stats( - dataset: Union[Dataset, Subset], dataset_info: DatasetInfo -) -> str: +def get_stats(dataset: Union[Dataset, Subset], dataset_info: DatasetInfo) -> str: """Returns the statistics of a dataset or subset as a string.""" dataset_len = len(dataset) @@ -306,7 +253,7 @@ def _get_dataset_stats( # target_names will be used to store names of the targets, # along with their gradients target_names = [] - for key, tensor_map in dataset[0].items(): + for key, tensor_map in dataset[0]._asdict().items(): if key == "system": continue target_names.append(key) @@ -408,8 +355,8 @@ def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]: target_names = [] for dataset in datasets: for sample in dataset: - sample.pop("system") # system not needed - target_names += list(sample.keys()) + # system not needed + target_names += [key for key in sample._asdict().keys() if key != "system"] return sorted(set(target_names)) @@ -422,6 +369,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]] """ collated_targets = group_and_join(batch) + collated_targets = collated_targets._asdict() systems = collated_targets.pop("system") return systems, collated_targets @@ -441,15 +389,15 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]): or targets that are not present in the training set """ # Check that system `dtypes` are consistent within datasets - desired_dtype = train_datasets[0][0]["system"].positions.dtype + desired_dtype = train_datasets[0][0].system.positions.dtype msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and " for train_dataset in train_datasets: - actual_dtype = train_dataset[0]["system"].positions.dtype + actual_dtype = train_dataset[0].system.positions.dtype if actual_dtype != desired_dtype: raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`") for val_dataset in val_datasets: - actual_dtype = val_dataset[0]["system"].positions.dtype + actual_dtype = val_dataset[0].system.positions.dtype if actual_dtype != desired_dtype: raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`") @@ -515,33 +463,3 @@ def _train_test_random_split( Subset(train_dataset, train_indices), Subset(train_dataset, test_indices), ] - - -def group_and_join( - batch: List[Dict[str, Any]], -) -> Dict[str, Any]: - """ - Same as metatenor.learn.data.group_and_join, but joins dicts and not named tuples. - - :param batch: A list of dictionaries, each containing the data for a single sample. - - :returns: A single dictionary with the data fields joined together among all - samples. - """ - data: List[Union[TensorMap, torch.Tensor]] = [] - names = batch[0].keys() - for name, f in zip(names, zip(*(item.values() for item in batch))): - if name == "sample_id": # special case, keep as is - data.append(f) - continue - - if isinstance(f[0], torch.ScriptObject) and f[0]._has_method( - "keys_to_properties" - ): # inferred metatensor.torch.TensorMap type - data.append(metatensor.torch.join(f, axis="samples")) - elif isinstance(f[0], torch.Tensor): # torch.Tensor type - data.append(torch.vstack(f)) - else: # otherwise just keep as a list - data.append(f) - - return {name: value for name, value in zip(names, data)} diff --git a/src/metatrain/utils/data/extract_targets.py b/src/metatrain/utils/data/extract_targets.py index fe39495b..ee86b29d 100644 --- a/src/metatrain/utils/data/extract_targets.py +++ b/src/metatrain/utils/data/extract_targets.py @@ -28,6 +28,7 @@ def get_targets_dict( targets_dict = {} for dataset in datasets: targets = next(iter(dataset)) + targets = targets._asdict() targets.pop("system") # system not needed # targets is now a dictionary of TensorMaps diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 2094bae4..2f95263c 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -27,6 +27,6 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]: reader=options["systems"]["reader"], ) targets, target_info_dictionary = read_targets(conf=options["targets"]) - dataset = Dataset({"system": systems, **targets}) + dataset = Dataset.from_dict({"system": systems, **targets}) return dataset, target_info_dictionary diff --git a/tests/utils/data/test_combine_dataloaders.py b/tests/utils/data/test_combine_dataloaders.py index 6f855059..beb78fbd 100644 --- a/tests/utils/data/test_combine_dataloaders.py +++ b/tests/utils/data/test_combine_dataloaders.py @@ -36,7 +36,7 @@ def test_without_shuffling(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn) # will yield 10 batches of 10 @@ -56,7 +56,7 @@ def test_without_shuffling(): } targets, _ = read_targets(OmegaConf.create(conf)) targets = {"mtt::free_energy": targets["mtt::free_energy"][:10]} - dataset = Dataset( + dataset = Dataset.from_dict( {"system": systems, "mtt::free_energy": targets["mtt::free_energy"]} ) dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn) @@ -94,7 +94,7 @@ def test_with_shuffling(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) dataloader_qm9 = DataLoader( dataset, batch_size=10, collate_fn=collate_fn, shuffle=True ) @@ -116,7 +116,7 @@ def test_with_shuffling(): } targets, _ = read_targets(OmegaConf.create(conf)) targets = {"mtt::free_energy": targets["mtt::free_energy"][:10]} - dataset = Dataset( + dataset = Dataset.from_dict( {"system": systems, "mtt::free_energy": targets["mtt::free_energy"]} ) dataloader_alchemical = DataLoader( diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index 989c9866..24c54bac 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -13,6 +13,7 @@ collate_fn, get_all_targets, get_atomic_types, + get_stats, read_systems, read_targets, ) @@ -418,7 +419,7 @@ def test_dataset(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "energy": targets["energy"]}) + dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) dataloader = torch.utils.data.DataLoader( dataset, batch_size=10, collate_fn=collate_fn ) @@ -458,8 +459,8 @@ def test_get_atomic_types(): } targets, _ = read_targets(OmegaConf.create(conf)) targets_2, _ = read_targets(OmegaConf.create(conf_2)) - dataset = Dataset({"system": systems, **targets}) - dataset_2 = Dataset({"system": systems_2, **targets_2}) + dataset = Dataset.from_dict({"system": systems, **targets}) + dataset_2 = Dataset.from_dict({"system": systems_2, **targets_2}) assert get_atomic_types(dataset) == [1, 6, 7, 8] assert get_atomic_types(dataset_2) == [1, 6, 8] @@ -497,8 +498,8 @@ def test_get_all_targets(): } targets, _ = read_targets(OmegaConf.create(conf)) targets_2, _ = read_targets(OmegaConf.create(conf_2)) - dataset = Dataset({"system": systems, **targets}) - dataset_2 = Dataset({"system": systems_2, **targets_2}) + dataset = Dataset.from_dict({"system": systems, **targets}) + dataset_2 = Dataset.from_dict({"system": systems_2, **targets_2}) assert get_all_targets(dataset) == ["mtt::U0"] assert get_all_targets(dataset_2) == ["energy"] assert get_all_targets([dataset, dataset_2]) == ["energy", "mtt::U0"] @@ -537,19 +538,19 @@ def test_check_datasets(): targets_ethanol, _ = read_targets(OmegaConf.create(conf_ethanol)) # everything ok - train_set = Dataset({"system": systems_qm9, **targets_qm9}) - val_set = Dataset({"system": systems_qm9, **targets_qm9}) + train_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) + val_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) check_datasets([train_set], [val_set]) # extra species in validation dataset - train_set = Dataset({"system": systems_ethanol, **targets_qm9}) - val_set = Dataset({"system": systems_qm9, **targets_qm9}) + train_set = Dataset.from_dict({"system": systems_ethanol, **targets_qm9}) + val_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) with pytest.raises(ValueError, match="The validation dataset has a species"): check_datasets([train_set], [val_set]) # extra targets in validation dataset - train_set = Dataset({"system": systems_qm9, **targets_qm9}) - val_set = Dataset({"system": systems_qm9, **targets_ethanol}) + train_set = Dataset.from_dict({"system": systems_qm9, **targets_qm9}) + val_set = Dataset.from_dict({"system": systems_qm9, **targets_ethanol}) with pytest.raises(ValueError, match="The validation dataset has a target"): check_datasets([train_set], [val_set]) @@ -558,7 +559,9 @@ def test_check_datasets(): targets_qm9_32bit = { k: [v.to(dtype=torch.float32) for v in l] for k, l in targets_qm9.items() } - train_set_32_bit = Dataset({"system": systems_qm9_32bit, **targets_qm9_32bit}) + train_set_32_bit = Dataset.from_dict( + {"system": systems_qm9_32bit, **targets_qm9_32bit} + ) match = ( "`dtype` between datasets is inconsistent, found torch.float64 and " @@ -592,7 +595,7 @@ def test_collate_fn(): } } targets, _ = read_targets(OmegaConf.create(conf)) - dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) batch = collate_fn([dataset[0], dataset[1], dataset[2]]) @@ -633,8 +636,8 @@ def test_get_stats(): } targets, _ = read_targets(OmegaConf.create(conf)) targets_2, _ = read_targets(OmegaConf.create(conf_2)) - dataset = Dataset({"system": systems, **targets}) - dataset_2 = Dataset({"system": systems_2, **targets_2}) + dataset = Dataset.from_dict({"system": systems, **targets}) + dataset_2 = Dataset.from_dict({"system": systems_2, **targets_2}) dataset_info = DatasetInfo( length_unit="angstrom", @@ -645,8 +648,8 @@ def test_get_stats(): }, ) - stats = dataset.get_stats(dataset_info) - stats_2 = dataset_2.get_stats(dataset_info) + stats = get_stats(dataset, dataset_info) + stats_2 = get_stats(dataset_2, dataset_info) assert "size 100" in stats assert "mtt::U0" in stats diff --git a/tests/utils/data/test_get_dataset.py b/tests/utils/data/test_get_dataset.py index 17fda35c..765f6a62 100644 --- a/tests/utils/data/test_get_dataset.py +++ b/tests/utils/data/test_get_dataset.py @@ -31,8 +31,8 @@ def test_get_dataset(): dataset, target_info = get_dataset(OmegaConf.create(options)) - assert "system" in dataset[0] - assert "energy" in dataset[0] + dataset[0].system + dataset[0].energy assert "energy" in target_info assert target_info["energy"].quantity == "energy" assert target_info["energy"].unit == "eV" diff --git a/tests/utils/test_composition.py b/tests/utils/test_composition.py index 1933d8d7..63a75e64 100644 --- a/tests/utils/test_composition.py +++ b/tests/utils/test_composition.py @@ -61,7 +61,7 @@ def test_calculate_composition_weights(): ) for i, e in enumerate(energies) ] - dataset = Dataset({"system": systems, "energy": energies}) + dataset = Dataset.from_dict({"system": systems, "energy": energies}) weights, atomic_types = calculate_composition_weights(dataset, "energy") diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 8debe872..f1887ef5 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -42,7 +42,7 @@ def test_llpr(tmpdir): get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems ] - dataset = Dataset({"system": qm9_systems, **targets}) + dataset = Dataset.from_dict({"system": qm9_systems, **targets}) dataloader = torch.utils.data.DataLoader( dataset, batch_size=10,