Update metatensor #331

Merged: 3 commits, Sep 2, 2024
2 changes: 1 addition & 1 deletion examples/programmatic/llpr/llpr.py
@@ -72,7 +72,7 @@
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in qm9_systems
]
dataset = Dataset({"system": qm9_systems, **targets})
dataset = Dataset.from_dict({"system": qm9_systems, **targets})

# We also load a single ethanol molecule on which we will compute properties.
# This system is loaded without targets, as we are only interested in the LPR
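A note on the pattern above, which recurs throughout this PR: the in-tree `Dataset` wrapper (removed from src/metatrain/utils/data/dataset.py below) took a plain dictionary, while the upstream `metatensor.learn` Dataset adopted here is built with `Dataset.from_dict` and yields named-tuple samples. A minimal sketch of the new usage; `systems` and `energies` are hypothetical placeholders for the values the readers produce:

# Sketch of the metatensor-learn Dataset API this PR migrates to.
# `systems` and `energies` stand in for equal-length lists of structures
# and target TensorMaps loaded elsewhere.
from metatensor.learn.data import Dataset

dataset = Dataset.from_dict({"system": systems, "energy": energies})

# Samples are named tuples: fields are attributes rather than dict keys,
# and _asdict() recovers a plain dictionary when one is needed.
sample = dataset[0]
print(sample.system)
print(sample._asdict()["energy"])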
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -10,9 +10,9 @@ authors = [{name = "metatrain developers"}]

dependencies = [
"ase < 3.23.0",
"metatensor-learn==0.2.2",
"metatensor-operations==0.2.1",
"metatensor-torch==0.5.3",
"metatensor-learn==0.2.3",
Luthaf marked this conversation as resolved.
Show resolved Hide resolved
"metatensor-operations==0.2.3",
"metatensor-torch==0.5.4",
"jsonschema",
"omegaconf",
"python-hostlist",
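The three metatensor pins move together here (learn 0.2.2 → 0.2.3, operations 0.2.1 → 0.2.3, torch 0.5.3 → 0.5.4). A small sketch, assuming the importlib.metadata distribution names match the pins above, for checking an existing environment against them:

# Sketch: compare installed metatensor versions against the pins in this PR.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "metatensor-learn": "0.2.3",
    "metatensor-operations": "0.2.3",
    "metatensor-torch": "0.5.4",
}

for package, pinned in PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed (pinned to {pinned})")
        continue
    status = "ok" if installed == pinned else f"pinned to {pinned}"
    print(f"{package}: {installed} ({status})")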
2 changes: 1 addition & 1 deletion src/metatrain/cli/eval.py
@@ -296,7 +296,7 @@ def eval_model(
gradients=gradients,
)

eval_dataset = Dataset({"system": eval_systems, **eval_targets})
eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})

# Evaluate the model
try:
16 changes: 12 additions & 4 deletions src/metatrain/cli/train.py
@@ -15,7 +15,13 @@

from .. import PACKAGE_ROOT
from ..utils.architectures import check_architecture_options, get_default_hypers
-from ..utils.data import DatasetInfo, TargetInfoDict, get_atomic_types, get_dataset
+from ..utils.data import (
+    DatasetInfo,
+    TargetInfoDict,
+    get_atomic_types,
+    get_dataset,
+    get_stats,
+)
from ..utils.data.dataset import _train_test_random_split
from ..utils.devices import pick_devices
from ..utils.distributed.logging import is_main_process
@@ -290,7 +296,7 @@ def train_model(
else:
index = f" {i}"
logger.info(
f"Training dataset{index}:\n {train_dataset.get_stats(dataset_info)}"
f"Training dataset{index}:\n {get_stats(train_dataset, dataset_info)}"
)

for i, val_dataset in enumerate(val_datasets):
@@ -299,15 +305,17 @@
else:
index = f" {i}"
logger.info(
f"Validation dataset{index}:\n {val_dataset.get_stats(dataset_info)}"
f"Validation dataset{index}:\n {get_stats(val_dataset, dataset_info)}"
)

for i, test_dataset in enumerate(test_datasets):
if len(test_datasets) == 1:
index = ""
else:
index = f" {i}"
logger.info(f"Test dataset{index}:\n {test_dataset.get_stats(dataset_info)}")
logger.info(
f"Test dataset{index}:\n {get_stats(test_dataset, dataset_info)}"
)

###########################
# SAVE EXPANDED OPTIONS ###
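With the wrapper classes gone, statistics come from the free function `get_stats(dataset, dataset_info)` (exported from `..utils.data`; see the __init__.py and dataset.py changes below) rather than from a `get_stats` method on the dataset object. A sketch of the logging pattern used above, assuming `train_datasets` and `dataset_info` exist as they do inside `train_model`:

# Sketch: logging dataset statistics via the new free function.
import logging

from metatrain.utils.data import get_stats

logger = logging.getLogger(__name__)

def log_training_stats(train_datasets, dataset_info):
    for i, train_dataset in enumerate(train_datasets):
        index = "" if len(train_datasets) == 1 else f" {i}"
        logger.info(
            f"Training dataset{index}:\n {get_stats(train_dataset, dataset_info)}"
        )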
@@ -92,7 +92,7 @@ def test_regression_train():
}
}
targets, target_info_dict = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()

10 changes: 5 additions & 5 deletions src/metatrain/experimental/alchemical_model/utils/normalize.py
@@ -18,10 +18,10 @@ def get_average_number_of_atoms(
"""
average_number_of_atoms = []
for dataset in datasets:
dtype = dataset[0]["system"].positions.dtype
dtype = dataset[0].system.positions.dtype
num_atoms = []
for i in range(len(dataset)):
system = dataset[i]["system"]
system = dataset[i].system
num_atoms.append(len(system))
average_number_of_atoms.append(torch.mean(torch.tensor(num_atoms, dtype=dtype)))
return torch.tensor(average_number_of_atoms)
@@ -39,9 +39,9 @@ def get_average_number_of_neighbors(
average_number_of_neighbors = []
for dataset in datasets:
num_neighbor = []
dtype = dataset[0]["system"].positions.dtype
dtype = dataset[0].system.positions.dtype
for i in range(len(dataset)):
system = dataset[i]["system"]
system = dataset[i].system
known_neighbor_lists = system.known_neighbor_lists()
if len(known_neighbor_lists) == 0:
raise ValueError(f"system {system} does not have a neighbor list")
@@ -94,4 +94,4 @@ def remove_composition_from_dataset(
new_systems.append(system)
new_properties.append(property)

return Dataset({"system": new_systems, property_name: new_properties})
return Dataset.from_dict({"system": new_systems, property_name: new_properties})
4 changes: 3 additions & 1 deletion src/metatrain/experimental/gap/tests/test_errors.py
@@ -54,7 +54,9 @@ def test_ethanol_regression_train_and_invariance():
}

targets, _ = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems[:2], "energy": targets["energy"][:2]})
dataset = Dataset.from_dict(
{"system": systems[:2], "energy": targets["energy"][:2]}
)

hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 30
4 changes: 2 additions & 2 deletions src/metatrain/experimental/gap/tests/test_regression.py
@@ -55,7 +55,7 @@ def test_regression_train_and_invariance():
}
}
targets, _ = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

target_info_dict = TargetInfoDict()
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
@@ -132,7 +132,7 @@ def test_ethanol_regression_train_and_invariance():
}

targets, _ = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems, "energy": targets["energy"]})
dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]})

hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 900
2 changes: 1 addition & 1 deletion src/metatrain/experimental/gap/tests/test_torchscript.py
@@ -36,7 +36,7 @@ def test_torchscript():

# for system in systems:
# system.types = torch.ones(len(system.types), dtype=torch.int32)
dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
gap = GAP(DEFAULT_HYPERS["model"], dataset_info)
@@ -159,7 +159,7 @@ def test_predictions_compatibility(cutoff):
"neighbors_pos": batch.neighbors_pos,
}

-pet = model._module.pet
+pet = model.module.pet

pet_prediction = pet.forward(batch_dict)

@@ -44,7 +44,7 @@ def test_continue(monkeypatch, tmp_path):
}
}
targets, _ = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 0
@@ -70,7 +70,7 @@ def test_regression_train():
}
}
targets, target_info_dict = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
2 changes: 1 addition & 1 deletion src/metatrain/utils/data/__init__.py
@@ -7,7 +7,7 @@
get_all_targets,
collate_fn,
check_datasets,
-group_and_join,
+get_stats,
)
from .readers import ( # noqa: F401
read_energy,
102 changes: 10 additions & 92 deletions src/metatrain/utils/data/dataset.py
@@ -3,10 +3,10 @@
from collections import UserDict
from typing import Any, Dict, List, Optional, Tuple, Union

-import metatensor.learn
import numpy as np
import torch
+from metatensor.learn.data import Dataset, group_and_join
from metatensor.torch import TensorMap
+from torch.utils.data import Subset

from ..external_naming import to_external_name
from ..units import get_gradient_units
@@ -242,60 +242,7 @@ def union(self, other: "DatasetInfo") -> "DatasetInfo":
return new


-class Dataset:
-    """A version of the `metatensor.learn.Dataset` class that allows for
-    the use of `mtt::` prefixes in the keys of the dictionary. See
-    https://github.com/lab-cosmo/metatensor/issues/621.
-
-    It is important to note that, instead of named tuples, this class
-    accepts and returns dictionaries.
-
-    :param dict: A dictionary with the data to be stored in the dataset.
-    """
-
-    def __init__(self, dict: Dict):
-
-        new_dict = {}
-        for key, value in dict.items():
-            key = key.replace("mtt::", "mtt_")
-            new_dict[key] = value
-
-        self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)
-
-    def __getitem__(self, idx: int) -> Dict:
-
-        mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
-        new_dict = {}
-        for key, value in mts_dataset_item.items():
-            key = key.replace("mtt_", "mtt::")
-            new_dict[key] = value
-
-        return new_dict
-
-    def __len__(self) -> int:
-        return len(self.mts_learn_dataset)
-
-    def __iter__(self):
-        for i in range(len(self)):
-            yield self[i]
-
-    def get_stats(self, dataset_info: DatasetInfo) -> str:
-        return _get_dataset_stats(self, dataset_info)
-
-
-class Subset(torch.utils.data.Subset):
-    """
-    A version of `torch.utils.data.Subset` containing a `get_stats` method
-    allowing us to print information about atomistic datasets.
-    """
-
-    def get_stats(self, dataset_info: DatasetInfo) -> str:
-        return _get_dataset_stats(self, dataset_info)
-
-
-def _get_dataset_stats(
-    dataset: Union[Dataset, Subset], dataset_info: DatasetInfo
-) -> str:
+def get_stats(dataset: Union[Dataset, Subset], dataset_info: DatasetInfo) -> str:
"""Returns the statistics of a dataset or subset as a string."""

dataset_len = len(dataset)
@@ -306,7 +253,7 @@ def _get_dataset_stats(
# target_names will be used to store names of the targets,
# along with their gradients
target_names = []
-for key, tensor_map in dataset[0].items():
+for key, tensor_map in dataset[0]._asdict().items():
if key == "system":
continue
target_names.append(key)
@@ -408,8 +355,8 @@ def get_all_targets(datasets: Union[Dataset, List[Dataset]]) -> List[str]:
target_names = []
for dataset in datasets:
for sample in dataset:
sample.pop("system") # system not needed
target_names += list(sample.keys())
# system not needed
target_names += [key for key in sample._asdict().keys() if key != "system"]

return sorted(set(target_names))

@@ -422,6 +369,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]]
"""

collated_targets = group_and_join(batch)
+collated_targets = collated_targets._asdict()
systems = collated_targets.pop("system")
return systems, collated_targets

@@ -441,15 +389,15 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]):
or targets that are not present in the training set
"""
# Check that system `dtypes` are consistent within datasets
desired_dtype = train_datasets[0][0]["system"].positions.dtype
desired_dtype = train_datasets[0][0].system.positions.dtype
msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and "
for train_dataset in train_datasets:
actual_dtype = train_dataset[0]["system"].positions.dtype
actual_dtype = train_dataset[0].system.positions.dtype
if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`")

for val_dataset in val_datasets:
actual_dtype = val_dataset[0]["system"].positions.dtype
actual_dtype = val_dataset[0].system.positions.dtype
if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`")

@@ -515,33 +463,3 @@ def _train_test_random_split(
Subset(train_dataset, train_indices),
Subset(train_dataset, test_indices),
]


-def group_and_join(
-    batch: List[Dict[str, Any]],
-) -> Dict[str, Any]:
-    """
-    Same as metatensor.learn.data.group_and_join, but joins dicts and not named tuples.
-
-    :param batch: A list of dictionaries, each containing the data for a single sample.
-
-    :returns: A single dictionary with the data fields joined together among all
-        samples.
-    """
-    data: List[Union[TensorMap, torch.Tensor]] = []
-    names = batch[0].keys()
-    for name, f in zip(names, zip(*(item.values() for item in batch))):
-        if name == "sample_id":  # special case, keep as is
-            data.append(f)
-            continue
-
-        if isinstance(f[0], torch.ScriptObject) and f[0]._has_method(
-            "keys_to_properties"
-        ):  # inferred metatensor.torch.TensorMap type
-            data.append(metatensor.torch.join(f, axis="samples"))
-        elif isinstance(f[0], torch.Tensor):  # torch.Tensor type
-            data.append(torch.vstack(f))
-        else:  # otherwise just keep as a list
-            data.append(f)
-
-    return {name: value for name, value in zip(names, data)}
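Net effect of this file's changes: the `mtt::`-prefix-rewriting `Dataset` wrapper, the `Subset` subclass, and the dict-based `group_and_join` are all dropped in favor of the upstream `metatensor.learn.data` versions (upstream now handles `mtt::` keys directly, per the metatensor issue #621 cited in the removed docstring), with `_asdict()` bridging named tuples back to the dictionaries the rest of the code expects. A sketch of the resulting collate path; the DataLoader wiring is an assumption for illustration, not taken from this diff:

# Sketch: collating batches with the upstream group_and_join after this PR.
from metatensor.learn.data import group_and_join

def collate_fn(batch):
    # group_and_join returns a named tuple of joined fields; _asdict()
    # converts it to the (systems, targets) pair used throughout metatrain
    collated = group_and_join(batch)._asdict()
    systems = collated.pop("system")
    return systems, collated

# Hypothetical usage with torch.utils.data.DataLoader, assuming `dataset`
# was built with Dataset.from_dict:
# loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)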
1 change: 1 addition & 0 deletions src/metatrain/utils/data/extract_targets.py
@@ -28,6 +28,7 @@ def get_targets_dict(
targets_dict = {}
for dataset in datasets:
targets = next(iter(dataset))
+targets = targets._asdict()
targets.pop("system") # system not needed

# targets is now a dictionary of TensorMaps
2 changes: 1 addition & 1 deletion src/metatrain/utils/data/get_dataset.py
@@ -27,6 +27,6 @@ def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]:
reader=options["systems"]["reader"],
)
targets, target_info_dictionary = read_targets(conf=options["targets"])
dataset = Dataset({"system": systems, **targets})
dataset = Dataset.from_dict({"system": systems, **targets})

return dataset, target_info_dictionary
2 changes: 1 addition & 1 deletion src/metatrain/utils/llpr.py
@@ -31,7 +31,7 @@ def __init__(
super().__init__()

self.model = model
-self.ll_feat_size = self.model._module.last_layer_feature_size
+self.ll_feat_size = self.model.module.last_layer_feature_size

# update capabilities: now we have additional outputs for the uncertainty
old_capabilities = self.model.capabilities()