Update metatensor (#331)
frostedoyster authored Sep 2, 2024
1 parent 6152ba6 commit 38efd99
Showing 22 changed files with 74 additions and 142 deletions.
2 changes: 1 addition & 1 deletion examples/programmatic/llpr/llpr.py
@@ -72,7 +72,7 @@
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in qm9_systems
]
-dataset = Dataset({"system": qm9_systems, **targets})
+dataset = Dataset.from_dict({"system": qm9_systems, **targets})

# We also load a single ethanol molecule on which we will compute properties.
# This system is loaded without targets, as we are only interested in the LPR
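The recurring change in this commit is the dataset constructor: the in-tree `Dataset` wrapper is deleted (see `src/metatrain/utils/data/dataset.py` below) in favor of `metatensor.learn`'s own class, built through `from_dict`. A minimal sketch of the new pattern, with `systems` and `targets` standing in for data loaded as in the surrounding code:

from metatensor.learn.data import Dataset

# from_dict accepts target keys directly, including "mtt::"-prefixed names
# that the old in-tree wrapper had to mangle to "mtt_".
dataset = Dataset.from_dict({"system": systems, **targets})

# Items are named tuples rather than dictionaries: fixed fields are read as
# attributes, and _asdict() gives a dict view when keys are only known at runtime.
sample = dataset[0]
system = sample.system
targets_only = {k: v for k, v in sample._asdict().items() if k != "system"}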
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -10,9 +10,9 @@ authors = [{name = "metatrain developers"}]

dependencies = [
"ase < 3.23.0",
-"metatensor-learn==0.2.2",
-"metatensor-operations==0.2.1",
-"metatensor-torch==0.5.3",
+"metatensor-learn==0.2.3",
+"metatensor-operations==0.2.3",
+"metatensor-torch==0.5.4",
"jsonschema",
"omegaconf",
"python-hostlist",
2 changes: 1 addition & 1 deletion src/metatrain/cli/eval.py
@@ -296,7 +296,7 @@ def eval_model(
gradients=gradients,
)

-eval_dataset = Dataset({"system": eval_systems, **eval_targets})
+eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})

# Evaluate the model
try:
16 changes: 12 additions & 4 deletions src/metatrain/cli/train.py
@@ -15,7 +15,13 @@

from .. import PACKAGE_ROOT
from ..utils.architectures import check_architecture_options, get_default_hypers
-from ..utils.data import DatasetInfo, TargetInfoDict, get_atomic_types, get_dataset
+from ..utils.data import (
+    DatasetInfo,
+    TargetInfoDict,
+    get_atomic_types,
+    get_dataset,
+    get_stats,
+)
from ..utils.data.dataset import _train_test_random_split
from ..utils.devices import pick_devices
from ..utils.distributed.logging import is_main_process
@@ -290,7 +296,7 @@
else:
index = f" {i}"
logger.info(
-f"Training dataset{index}:\n {train_dataset.get_stats(dataset_info)}"
+f"Training dataset{index}:\n {get_stats(train_dataset, dataset_info)}"
)

for i, val_dataset in enumerate(val_datasets):
@@ -299,15 +305,17 @@
else:
index = f" {i}"
logger.info(
-f"Validation dataset{index}:\n {val_dataset.get_stats(dataset_info)}"
+f"Validation dataset{index}:\n {get_stats(val_dataset, dataset_info)}"
)

for i, test_dataset in enumerate(test_datasets):
if len(test_datasets) == 1:
index = ""
else:
index = f" {i}"
-logger.info(f"Test dataset{index}:\n {test_dataset.get_stats(dataset_info)}")
+logger.info(
+    f"Test dataset{index}:\n {get_stats(test_dataset, dataset_info)}"
+)

###########################
# SAVE EXPANDED OPTIONS ###
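Dataset statistics change shape here as well: `get_stats` is now a free function imported from `metatrain.utils.data` (its definition moves in `dataset.py` below) instead of a method on the removed wrapper classes. A short sketch of the new call, assuming `dataset` and `dataset_info` are set up as in the training code above:

from metatrain.utils.data import get_stats

# get_stats accepts both Dataset and Subset objects, which is what made
# the per-class get_stats methods redundant.
print(get_stats(dataset, dataset_info))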
@@ -92,7 +92,7 @@ def test_regression_train():
}
}
targets, target_info_dict = read_targets(OmegaConf.create(conf))
-dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()

10 changes: 5 additions & 5 deletions src/metatrain/experimental/alchemical_model/utils/normalize.py
@@ -18,10 +18,10 @@ def get_average_number_of_atoms(
"""
average_number_of_atoms = []
for dataset in datasets:
-dtype = dataset[0]["system"].positions.dtype
+dtype = dataset[0].system.positions.dtype
num_atoms = []
for i in range(len(dataset)):
-system = dataset[i]["system"]
+system = dataset[i].system
num_atoms.append(len(system))
average_number_of_atoms.append(torch.mean(torch.tensor(num_atoms, dtype=dtype)))
return torch.tensor(average_number_of_atoms)
@@ -39,9 +39,9 @@ def get_average_number_of_neighbors(
average_number_of_neighbors = []
for dataset in datasets:
num_neighbor = []
-dtype = dataset[0]["system"].positions.dtype
+dtype = dataset[0].system.positions.dtype
for i in range(len(dataset)):
-system = dataset[i]["system"]
+system = dataset[i].system
known_neighbor_lists = system.known_neighbor_lists()
if len(known_neighbor_lists) == 0:
raise ValueError(f"system {system} does not have a neighbor list")
@@ -94,4 +94,4 @@ def remove_composition_from_dataset(
new_systems.append(system)
new_properties.append(property)

-return Dataset({"system": new_systems, property_name: new_properties})
+return Dataset.from_dict({"system": new_systems, property_name: new_properties})
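The sample-access changes in this file follow from the same migration: `metatensor.learn` dataset items are named tuples, so fields are attributes rather than dictionary entries. A before/after sketch, assuming samples that carry a `system` field:

# Old in-tree wrapper returned plain dictionaries:
#   system = dataset[i]["system"]
# metatensor.learn items are named tuples, so the field is an attribute:
system = dataset[i].system
dtype = system.positions.dtype  # e.g. when collecting per-dataset dtypes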
4 changes: 3 additions & 1 deletion src/metatrain/experimental/gap/tests/test_errors.py
@@ -54,7 +54,9 @@ def test_ethanol_regression_train_and_invariance():
}

targets, _ = read_targets(OmegaConf.create(conf))
-dataset = Dataset({"system": systems[:2], "energy": targets["energy"][:2]})
+dataset = Dataset.from_dict(
+    {"system": systems[:2], "energy": targets["energy"][:2]}
+)

hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 30
4 changes: 2 additions & 2 deletions src/metatrain/experimental/gap/tests/test_regression.py
@@ -55,7 +55,7 @@ def test_regression_train_and_invariance():
}
}
targets, _ = read_targets(OmegaConf.create(conf))
-dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

target_info_dict = TargetInfoDict()
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
@@ -132,7 +132,7 @@ def test_ethanol_regression_train_and_invariance():
}

targets, _ = read_targets(OmegaConf.create(conf))
-dataset = Dataset({"system": systems, "energy": targets["energy"]})
+dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]})

hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 900
2 changes: 1 addition & 1 deletion src/metatrain/experimental/gap/tests/test_torchscript.py
@@ -36,7 +36,7 @@ def test_torchscript():

# for system in systems:
# system.types = torch.ones(len(system.types), dtype=torch.int32)
-dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
gap = GAP(DEFAULT_HYPERS["model"], dataset_info)
@@ -159,7 +159,7 @@ def test_predictions_compatibility(cutoff):
"neighbors_pos": batch.neighbors_pos,
}

-pet = model._module.pet
+pet = model.module.pet

pet_prediction = pet.forward(batch_dict)

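Exported models now expose the wrapped torch module through the public `module` attribute rather than the private `_module`; the same rename appears in `src/metatrain/utils/llpr.py` below. A sketch, assuming `model` and `batch_dict` are set up as in this test:

pet = model.module.pet  # was: model._module.pet
pet_prediction = pet.forward(batch_dict)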
@@ -44,7 +44,7 @@ def test_continue(monkeypatch, tmp_path):
}
}
targets, _ = read_targets(OmegaConf.create(conf))
-dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 0
@@ -70,7 +70,7 @@ def test_regression_train():
}
}
targets, target_info_dict = read_targets(OmegaConf.create(conf))
-dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
+dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
2 changes: 1 addition & 1 deletion src/metatrain/utils/data/__init__.py
@@ -7,7 +7,7 @@
get_all_targets,
collate_fn,
check_datasets,
-group_and_join,
+get_stats,
)
from .readers import ( # noqa: F401
read_energy,
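In other words, `metatrain.utils.data` stops re-exporting the deleted in-tree `group_and_join` and starts exporting the new `get_stats` helper; code that still needs `group_and_join` imports the upstream version. A sketch of the resulting import surface, assuming the re-exports listed above:

from metatrain.utils.data import get_stats, collate_fn, check_datasets
from metatensor.learn.data import Dataset, group_and_join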
102 changes: 10 additions & 92 deletions src/metatrain/utils/data/dataset.py
@@ -3,10 +3,10 @@
from collections import UserDict
from typing import Any, Dict, List, Optional, Tuple, Union

-import metatensor.learn
import numpy as np
import torch
+from metatensor.learn.data import Dataset, group_and_join
from metatensor.torch import TensorMap
+from torch.utils.data import Subset

from ..external_naming import to_external_name
from ..units import get_gradient_units
@@ -242,60 +242,7 @@ def union(self, other: "DatasetInfo") -> "DatasetInfo":
return new


-class Dataset:
-    """A version of the `metatensor.learn.Dataset` class that allows for
-    the use of `mtt::` prefixes in the keys of the dictionary. See
-    https://github.com/lab-cosmo/metatensor/issues/621.
-    It is important to note that, instead of named tuples, this class
-    accepts and returns dictionaries.
-    :param dict: A dictionary with the data to be stored in the dataset.
-    """
-
-    def __init__(self, dict: Dict):
-
-        new_dict = {}
-        for key, value in dict.items():
-            key = key.replace("mtt::", "mtt_")
-            new_dict[key] = value
-
-        self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)
-
-    def __getitem__(self, idx: int) -> Dict:
-
-        mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
-        new_dict = {}
-        for key, value in mts_dataset_item.items():
-            key = key.replace("mtt_", "mtt::")
-            new_dict[key] = value
-
-        return new_dict
-
-    def __len__(self) -> int:
-        return len(self.mts_learn_dataset)
-
-    def __iter__(self):
-        for i in range(len(self)):
-            yield self[i]
-
-    def get_stats(self, dataset_info: DatasetInfo) -> str:
-        return _get_dataset_stats(self, dataset_info)
-
-
-class Subset(torch.utils.data.Subset):
-    """
-    A version of `torch.utils.data.Subset` containing a `get_stats` method
-    allowing us to print information about atomistic datasets.
-    """
-
-    def get_stats(self, dataset_info: DatasetInfo) -> str:
-        return _get_dataset_stats(self, dataset_info)
-
-
-def _get_dataset_stats(
-    dataset: Union[Dataset, Subset], dataset_info: DatasetInfo
-) -> str:
+def get_stats(dataset: Union[Dataset, Subset], dataset_info: DatasetInfo) -> str:
"""Returns the statistics of a dataset or subset as a string."""

dataset_len = len(dataset)
@@ -306,7 +253,7 @@
# target_names will be used to store names of the targets,
# along with their gradients
target_names = []
-for key, tensor_map in dataset[0].items():
+for key, tensor_map in dataset[0]._asdict().items():
if key == "system":
continue
target_names.append(key)
@@ -408,8 +355,8 @@ def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]:
target_names = []
for dataset in datasets:
for sample in dataset:
-sample.pop("system")  # system not needed
-target_names += list(sample.keys())
+# system not needed
+target_names += [key for key in sample._asdict().keys() if key != "system"]

return sorted(set(target_names))

@@ -422,6 +369,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]]
"""

collated_targets = group_and_join(batch)
+collated_targets = collated_targets._asdict()
systems = collated_targets.pop("system")
return systems, collated_targets

@@ -441,15 +389,15 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]):
or targets that are not present in the training set
"""
# Check that system `dtypes` are consistent within datasets
-desired_dtype = train_datasets[0][0]["system"].positions.dtype
+desired_dtype = train_datasets[0][0].system.positions.dtype
msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and "
for train_dataset in train_datasets:
-actual_dtype = train_dataset[0]["system"].positions.dtype
+actual_dtype = train_dataset[0].system.positions.dtype
if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`")

for val_dataset in val_datasets:
-actual_dtype = val_dataset[0]["system"].positions.dtype
+actual_dtype = val_dataset[0].system.positions.dtype
if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`")

@@ -515,33 +463,3 @@ def _train_test_random_split(
Subset(train_dataset, train_indices),
Subset(train_dataset, test_indices),
]


-def group_and_join(
-    batch: List[Dict[str, Any]],
-) -> Dict[str, Any]:
-    """
-    Same as metatensor.learn.data.group_and_join, but joins dicts and not named tuples.
-    :param batch: A list of dictionaries, each containing the data for a single sample.
-    :returns: A single dictionary with the data fields joined together among all
-        samples.
-    """
-    data: List[Union[TensorMap, torch.Tensor]] = []
-    names = batch[0].keys()
-    for name, f in zip(names, zip(*(item.values() for item in batch))):
-        if name == "sample_id":  # special case, keep as is
-            data.append(f)
-            continue
-
-        if isinstance(f[0], torch.ScriptObject) and f[0]._has_method(
-            "keys_to_properties"
-        ):  # inferred metatensor.torch.TensorMap type
-            data.append(metatensor.torch.join(f, axis="samples"))
-        elif isinstance(f[0], torch.Tensor):  # torch.Tensor type
-            data.append(torch.vstack(f))
-        else:  # otherwise just keep as a list
-            data.append(f)
-
-    return {name: value for name, value in zip(names, data)}
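With the in-tree `Dataset`, `Subset`, and `group_and_join` replacements gone, collation relies on `metatensor.learn.data.group_and_join`, which joins TensorMaps along samples, stacks plain tensors, and returns a named tuple; the only extra step is converting back to the dictionary shape the rest of metatrain expects. A sketch of the resulting collation path, assuming `batch` is a list of dataset items as produced by a DataLoader:

from metatensor.learn.data import group_and_join

def collate_fn(batch):
    # group_and_join returns a named tuple of joined fields; _asdict()
    # recovers a plain dictionary, from which the systems are split off.
    collated = group_and_join(batch)._asdict()
    systems = collated.pop("system")
    return systems, collated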
1 change: 1 addition & 0 deletions src/metatrain/utils/data/extract_targets.py
@@ -28,6 +28,7 @@ def get_targets_dict(
targets_dict = {}
for dataset in datasets:
targets = next(iter(dataset))
+targets = targets._asdict()
targets.pop("system") # system not needed

# targets is now a dictionary of TensorMaps
2 changes: 1 addition & 1 deletion src/metatrain/utils/data/get_dataset.py
@@ -27,6 +27,6 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]:
reader=options["systems"]["reader"],
)
targets, target_info_dictionary = read_targets(conf=options["targets"])
-dataset = Dataset({"system": systems, **targets})
+dataset = Dataset.from_dict({"system": systems, **targets})

return dataset, target_info_dictionary
2 changes: 1 addition & 1 deletion src/metatrain/utils/llpr.py
@@ -31,7 +31,7 @@ def __init__(
super().__init__()

self.model = model
-self.ll_feat_size = self.model._module.last_layer_feature_size
+self.ll_feat_size = self.model.module.last_layer_feature_size

# update capabilities: now we have additional outputs for the uncertainty
old_capabilities = self.model.capabilities()