Commit

Update dataset
frostedoyster committed Aug 30, 2024
1 parent e297fad commit 48f6bb3
Showing 18 changed files with 51 additions and 117 deletions.
examples/programmatic/llpr/llpr.py (2 changes: 1 addition & 1 deletion)

@@ -72,7 +72,7 @@
     get_system_with_neighbor_lists(system, requested_neighbor_lists)
     for system in qm9_systems
 ]
-dataset = Dataset({"system": qm9_systems, **targets})
+dataset = Dataset.from_dict({"system": qm9_systems, **targets})

 # We also load a single ethanol molecule on which we will compute properties.
 # This system is loaded without targets, as we are only interested in the LPR
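
Note: metatensor-learn's Dataset.from_dict accepts dictionary keys that are not
valid Python identifiers, such as the "mtt::"-prefixed target names, so the
custom wrapper that metatrain carried for this purpose can be dropped (see
src/metatrain/utils/data/dataset.py below). A minimal sketch of the new
pattern, with hypothetical stand-in values (any equal-length lists of
per-sample objects should work):

    import torch
    from metatensor.learn.data import Dataset

    # stand-ins for the systems and targets produced by the usual readers
    systems = [torch.rand(3, 3) for _ in range(4)]
    energies = [torch.rand(1) for _ in range(4)]

    dataset = Dataset.from_dict({"system": systems, "mtt::U0": energies})

    sample = dataset[0]                   # samples are named tuples, not dicts
    system = sample.system                # attribute access for plain field names
    energy = sample._asdict()["mtt::U0"]  # dict view for "mtt::"-prefixed fields
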
src/metatrain/cli/eval.py (2 changes: 1 addition & 1 deletion)

@@ -296,7 +296,7 @@ def eval_model(
         gradients=gradients,
     )

-    eval_dataset = Dataset({"system": eval_systems, **eval_targets})
+    eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})

     # Evaluate the model
     try:
(unnamed file)

@@ -92,7 +92,7 @@ def test_regression_train():
         }
     }
     targets, target_info_dict = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

     hypers = DEFAULT_HYPERS.copy()
src/metatrain/experimental/alchemical_model/utils/normalize.py (10 changes: 5 additions & 5 deletions)

@@ -18,10 +18,10 @@ def get_average_number_of_atoms(
"""
average_number_of_atoms = []
for dataset in datasets:
dtype = dataset[0]["system"].positions.dtype
dtype = dataset[0].system.positions.dtype
num_atoms = []
for i in range(len(dataset)):
system = dataset[i]["system"]
system = dataset[i].system
num_atoms.append(len(system))
average_number_of_atoms.append(torch.mean(torch.tensor(num_atoms, dtype=dtype)))
return torch.tensor(average_number_of_atoms)
@@ -39,9 +39,9 @@ def get_average_number_of_neighbors(
     average_number_of_neighbors = []
     for dataset in datasets:
         num_neighbor = []
-        dtype = dataset[0]["system"].positions.dtype
+        dtype = dataset[0].system.positions.dtype
         for i in range(len(dataset)):
-            system = dataset[i]["system"]
+            system = dataset[i].system
             known_neighbor_lists = system.known_neighbor_lists()
             if len(known_neighbor_lists) == 0:
                 raise ValueError(f"system {system} does not have a neighbor list")
@@ -94,4 +94,4 @@ def remove_composition_from_dataset(
         new_systems.append(system)
         new_properties.append(property)

-    return Dataset({"system": new_systems, property_name: new_properties})
+    return Dataset.from_dict({"system": new_systems, property_name: new_properties})
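
Note: the access-pattern changes in this file follow from the same migration:
samples returned by the metatensor-learn Dataset are named tuples rather than
the plain dicts produced by the removed wrapper. A short sketch of the
equivalent access styles (assuming a dataset built as above):

    sample = dataset[0]
    system = sample.system                # new style: attribute access
    system = sample._asdict()["system"]   # dict view, needed for "mtt::" keys
    # old style, with the removed dict-based wrapper: dataset[0]["system"]
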
src/metatrain/experimental/gap/tests/test_errors.py (4 changes: 3 additions & 1 deletion)

@@ -54,7 +54,9 @@ def test_ethanol_regression_train_and_invariance():
    }

    targets, _ = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems[:2], "energy": targets["energy"][:2]})
+    dataset = Dataset.from_dict(
+        {"system": systems[:2], "energy": targets["energy"][:2]}
+    )

    hypers = copy.deepcopy(DEFAULT_HYPERS)
    hypers["model"]["krr"]["num_sparse_points"] = 30
src/metatrain/experimental/gap/tests/test_regression.py (4 changes: 2 additions & 2 deletions)

@@ -55,7 +55,7 @@ def test_regression_train_and_invariance():
         }
     }
     targets, _ = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

     target_info_dict = TargetInfoDict()
     target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
@@ -132,7 +132,7 @@ def test_ethanol_regression_train_and_invariance():
     }

     targets, _ = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "energy": targets["energy"]})
+    dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]})

     hypers = copy.deepcopy(DEFAULT_HYPERS)
     hypers["model"]["krr"]["num_sparse_points"] = 900
src/metatrain/experimental/gap/tests/test_torchscript.py (2 changes: 1 addition & 1 deletion)

@@ -36,7 +36,7 @@ def test_torchscript():

     # for system in systems:
     #     system.types = torch.ones(len(system.types), dtype=torch.int32)
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

     hypers = DEFAULT_HYPERS.copy()
     gap = GAP(DEFAULT_HYPERS["model"], dataset_info)
(unnamed file)

@@ -44,7 +44,7 @@ def test_continue(monkeypatch, tmp_path):
         }
     }
     targets, _ = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

     hypers = DEFAULT_HYPERS.copy()
     hypers["training"]["num_epochs"] = 0
(unnamed file)

@@ -70,7 +70,7 @@ def test_regression_train():
         }
     }
     targets, target_info_dict = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

     hypers = DEFAULT_HYPERS.copy()
     hypers["training"]["num_epochs"] = 2
src/metatrain/utils/data/__init__.py (1 change: 0 additions & 1 deletion)

@@ -7,7 +7,6 @@
     get_all_targets,
     collate_fn,
     check_datasets,
-    group_and_join,
 )
 from .readers import (  # noqa: F401
     read_energy,
src/metatrain/utils/data/dataset.py (86 changes: 8 additions & 78 deletions)

@@ -3,9 +3,9 @@
 from collections import UserDict
 from typing import Any, Dict, List, Optional, Tuple, Union

-import metatensor.learn
 import numpy as np
 import torch
+from metatensor.learn.data import Dataset, group_and_join
 from metatensor.torch import TensorMap

 from ..external_naming import to_external_name
@@ -242,47 +242,6 @@ def union(self, other: "DatasetInfo") -> "DatasetInfo":
     return new


-class Dataset:
-    """A version of the `metatensor.learn.Dataset` class that allows for
-    the use of `mtt::` prefixes in the keys of the dictionary. See
-    https://github.com/lab-cosmo/metatensor/issues/621.
-
-    It is important to note that, instead of named tuples, this class
-    accepts and returns dictionaries.
-
-    :param dict: A dictionary with the data to be stored in the dataset.
-    """
-
-    def __init__(self, dict: Dict):
-
-        new_dict = {}
-        for key, value in dict.items():
-            key = key.replace("mtt::", "mtt_")
-            new_dict[key] = value
-
-        self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)
-
-    def __getitem__(self, idx: int) -> Dict:
-
-        mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
-        new_dict = {}
-        for key, value in mts_dataset_item.items():
-            key = key.replace("mtt_", "mtt::")
-            new_dict[key] = value
-
-        return new_dict
-
-    def __len__(self) -> int:
-        return len(self.mts_learn_dataset)
-
-    def __iter__(self):
-        for i in range(len(self)):
-            yield self[i]
-
-    def get_stats(self, dataset_info: DatasetInfo) -> str:
-        return _get_dataset_stats(self, dataset_info)
-
-
 class Subset(torch.utils.data.Subset):
     """
     A version of `torch.utils.data.Subset` containing a `get_stats` method
@@ -306,7 +265,7 @@ def _get_dataset_stats(
     # target_names will be used to store names of the targets,
     # along with their gradients
     target_names = []
-    for key, tensor_map in dataset[0].items():
+    for key, tensor_map in dataset[0]._asdict().items():
         if key == "system":
             continue
         target_names.append(key)
@@ -408,8 +367,8 @@ def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]:
     target_names = []
     for dataset in datasets:
         for sample in dataset:
-            sample.pop("system")  # system not needed
-            target_names += list(sample.keys())
+            # system not needed
+            target_names += [key for key in sample._asdict().keys() if key != "system"]

     return sorted(set(target_names))
@@ -422,6 +381,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]]
"""

collated_targets = group_and_join(batch)
collated_targets = collated_targets._asdict()
systems = collated_targets.pop("system")
return systems, collated_targets

@@ -441,15 +401,15 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]):
         or targets that are not present in the training set
     """
     # Check that system `dtypes` are consistent within datasets
-    desired_dtype = train_datasets[0][0]["system"].positions.dtype
+    desired_dtype = train_datasets[0][0].system.positions.dtype
     msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and "
     for train_dataset in train_datasets:
-        actual_dtype = train_dataset[0]["system"].positions.dtype
+        actual_dtype = train_dataset[0].system.positions.dtype
         if actual_dtype != desired_dtype:
             raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`")

     for val_dataset in val_datasets:
-        actual_dtype = val_dataset[0]["system"].positions.dtype
+        actual_dtype = val_dataset[0].system.positions.dtype
         if actual_dtype != desired_dtype:
             raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`")
@@ -515,33 +475,3 @@ def _train_test_random_split(
         Subset(train_dataset, train_indices),
         Subset(train_dataset, test_indices),
     ]
-
-
-def group_and_join(
-    batch: List[Dict[str, Any]],
-) -> Dict[str, Any]:
-    """
-    Same as metatensor.learn.data.group_and_join, but joins dicts and not named tuples.
-
-    :param batch: A list of dictionaries, each containing the data for a single sample.
-
-    :returns: A single dictionary with the data fields joined together among all
-        samples.
-    """
-    data: List[Union[TensorMap, torch.Tensor]] = []
-    names = batch[0].keys()
-    for name, f in zip(names, zip(*(item.values() for item in batch))):
-        if name == "sample_id":  # special case, keep as is
-            data.append(f)
-            continue
-
-        if isinstance(f[0], torch.ScriptObject) and f[0]._has_method(
-            "keys_to_properties"
-        ):  # inferred metatensor.torch.TensorMap type
-            data.append(metatensor.torch.join(f, axis="samples"))
-        elif isinstance(f[0], torch.Tensor):  # torch.Tensor type
-            data.append(torch.vstack(f))
-        else:  # otherwise just keep as a list
-            data.append(f)
-
-    return {name: value for name, value in zip(names, data)}
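
Note: collate_fn now delegates to metatensor-learn's group_and_join rather
than the local dict-based copy deleted above, converting the returned named
tuple with _asdict() before popping "system". Judging from the removed
implementation, TensorMap fields are joined along the samples axis, plain
tensors are stacked, and other fields are kept as lists; the upstream function
is assumed to behave the same way, just on named tuples. A usage sketch
(assuming a dataset built with Dataset.from_dict as in the other files of this
commit):

    from torch.utils.data import DataLoader

    dataloader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
    for systems, targets in dataloader:
        # systems: the batch's systems, popped out of the joined fields
        # targets: dict mapping target names (e.g. "mtt::U0") to TensorMaps
        ...
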
src/metatrain/utils/data/extract_targets.py (1 change: 1 addition & 0 deletions)

@@ -28,6 +28,7 @@ def get_targets_dict(
     targets_dict = {}
     for dataset in datasets:
         targets = next(iter(dataset))
+        targets = targets._asdict()
         targets.pop("system")  # system not needed

         # targets is now a dictionary of TensorMaps
src/metatrain/utils/data/get_dataset.py (2 changes: 1 addition & 1 deletion)

@@ -27,6 +27,6 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]:
         reader=options["systems"]["reader"],
     )
     targets, target_info_dictionary = read_targets(conf=options["targets"])
-    dataset = Dataset({"system": systems, **targets})
+    dataset = Dataset.from_dict({"system": systems, **targets})

     return dataset, target_info_dictionary
tests/utils/data/test_combine_dataloaders.py (8 changes: 4 additions & 4 deletions)

@@ -36,7 +36,7 @@ def test_without_shuffling():
         }
     }
     targets, _ = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})
     dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
     # will yield 10 batches of 10

@@ -56,7 +56,7 @@ def test_without_shuffling():
     }
     targets, _ = read_targets(OmegaConf.create(conf))
     targets = {"mtt::free_energy": targets["mtt::free_energy"][:10]}
-    dataset = Dataset(
+    dataset = Dataset.from_dict(
         {"system": systems, "mtt::free_energy": targets["mtt::free_energy"]}
     )
     dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
@@ -94,7 +94,7 @@ def test_with_shuffling():
         }
     }
     targets, _ = read_targets(OmegaConf.create(conf))
-    dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+    dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})
     dataloader_qm9 = DataLoader(
         dataset, batch_size=10, collate_fn=collate_fn, shuffle=True
     )
@@ -116,7 +116,7 @@ def test_with_shuffling():
     }
     targets, _ = read_targets(OmegaConf.create(conf))
     targets = {"mtt::free_energy": targets["mtt::free_energy"][:10]}
-    dataset = Dataset(
+    dataset = Dataset.from_dict(
         {"system": systems, "mtt::free_energy": targets["mtt::free_energy"]}
     )
     dataloader_alchemical = DataLoader(
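
Note: these hunks only touch dataset construction; the combination of the two
dataloaders is outside the visible diff. A hypothetical sanity check on the
batching configured above, with sample counts inferred from the "will yield 10
batches of 10" comment and the [:10] slicing (illustration only, not part of
the test):

    assert len(list(dataloader_qm9)) == 10        # 100 structures, batch size 10
    assert len(list(dataloader_alchemical)) == 5  # 10 structures, batch size 2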