Skip to content

Commit

Permalink
Rewrite target and dataset info classes (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri authored May 30, 2024
1 parent 115f13e commit bcfd69b
Show file tree
Hide file tree
Showing 40 changed files with 716 additions and 404 deletions.
27 changes: 9 additions & 18 deletions src/metatensor/models/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..utils.data import (
Dataset,
TargetInfo,
TargetInfoDict,
collate_fn,
read_systems,
read_targets,
Expand Down Expand Up @@ -137,7 +138,7 @@ def _concatenate_tensormaps(
def _eval_targets(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
dataset: Union[Dataset, torch.utils.data.Subset],
options: Dict[str, TargetInfo],
options: TargetInfoDict,
return_predictions: bool,
) -> Optional[Dict[str, TensorMap]]:
"""Evaluates an exported model on a dataset and prints the RMSEs for each target.
Expand Down Expand Up @@ -253,34 +254,24 @@ def eval_model(
if hasattr(options, "targets"):
# in this case, we only evaluate the targets specified in the options
# and we calculate RMSEs
eval_targets = read_targets(options["targets"], dtype=dtype)
eval_outputs = {
key: TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
per_atom=False, # TODO: allow the user to specify this
gradients=tensormaps[0].block().gradients_list(),
)
for key, tensormaps in eval_targets.items()
}
eval_targets, eval_info_dict = read_targets(options["targets"], dtype=dtype)
else:
# in this case, we have no targets: we evaluate everything
# (but we don't/can't calculate RMSEs)
# TODO: allow the user to specify which outputs to evaluate
eval_targets = {}
gradients = ["positions"]
eval_info_dict = TargetInfoDict()
gradients = {"positions"}
if all(not torch.all(system.cell == 0) for system in eval_systems):
# only add strain if all structures have cells
gradients.append("strain")
eval_outputs = {
key: TargetInfo(
gradients.add("strain")
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
per_atom=False, # TODO: allow the user to specify this
gradients=gradients,
)
for key in model.capabilities().outputs.keys()
}

eval_dataset = Dataset({"system": eval_systems, **eval_targets})

Expand All @@ -289,7 +280,7 @@ def eval_model(
predictions = _eval_targets(
model=model,
dataset=eval_dataset,
options=eval_outputs,
options=eval_info_dict,
return_predictions=True,
)
except Exception as e:
Expand Down
47 changes: 16 additions & 31 deletions src/metatensor/models/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Union

import numpy as np
import torch
Expand All @@ -15,7 +15,7 @@
from ..utils.data import (
Dataset,
DatasetInfo,
TargetInfo,
TargetInfoDict,
get_atomic_types,
read_systems,
read_targets,
Expand Down Expand Up @@ -179,13 +179,18 @@ def train_model(
check_options_list(train_options_list)

train_datasets = []
target_infos = TargetInfoDict()
for train_options in train_options_list:
train_systems = read_systems(
filename=train_options["systems"]["read_from"],
fileformat=train_options["systems"]["file_format"],
dtype=dtype,
)
train_targets = read_targets(conf=train_options["targets"], dtype=dtype)
train_targets, target_info_dictionary = read_targets(
conf=train_options["targets"], dtype=dtype
)

target_infos.update(target_info_dictionary)
train_datasets.append(Dataset({"system": train_systems, **train_targets}))

train_size = 1.0
Expand Down Expand Up @@ -240,7 +245,7 @@ def train_model(
fileformat=test_options["systems"]["file_format"],
dtype=dtype,
)
test_targets = read_targets(conf=test_options["targets"], dtype=dtype)
test_targets, _ = read_targets(conf=test_options["targets"], dtype=dtype)
test_dataset = Dataset({"system": test_systems, **test_targets})
test_datasets.append(test_dataset)

Expand Down Expand Up @@ -295,7 +300,7 @@ def train_model(
fileformat=validation_options["systems"]["file_format"],
dtype=dtype,
)
validation_targets = read_targets(
validation_targets, _ = read_targets(
conf=validation_options["targets"], dtype=dtype
)
validation_dataset = Dataset(
Expand All @@ -315,34 +320,14 @@ def train_model(
# CREATE DATASET_INFO #####
###########################

# TODO: move this into own function
# TODO: A more direct way to look up the gradients would be to get them from the
# configuration dict of the training run.
gradients: Dict[str, List[str]] = {}
for train_options in train_options_list:
for key in train_options["targets"].keys():
# look inside training sets and find gradients
for train_dataset in train_datasets:
if key in train_dataset[0].keys():
gradients[key] = train_dataset[0][key].block().gradients_list()
atomic_types = get_atomic_types(
train_datasets + train_datasets + validation_datasets
)

dataset_info = DatasetInfo(
length_unit=(
train_options_list[0]["systems"]["length_unit"]
if train_options_list[0]["systems"]["length_unit"] is not None
else ""
), # these units are guaranteed to be the same across all datasets
atomic_types=get_atomic_types(train_datasets + validation_datasets),
targets={
key: TargetInfo(
quantity=value["quantity"],
unit=(value["unit"] if value["unit"] is not None else ""),
per_atom=False, # TODO: read this from the config
gradients=gradients[key],
)
for train_options in train_options_list
for key, value in train_options["targets"].items()
},
length_unit=train_options_list[0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
)

###########################
Expand Down
9 changes: 5 additions & 4 deletions src/metatensor/models/experimental/alchemical_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
super().__init__()
self.hypers = model_hypers
self.dataset_info = dataset_info
self.atomic_types = sorted(dataset_info.atomic_types)

if len(dataset_info.targets) != 1:
raise ValueError("The AlchemicalModel only supports a single target")
Expand All @@ -49,7 +50,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
}

self.alchemical_model = AlchemicalModelUpstream(
unique_numbers=self.dataset_info.atomic_types,
unique_numbers=self.atomic_types,
**self.hypers["soap"],
**self.hypers["bpnn"],
)
Expand Down Expand Up @@ -158,7 +159,7 @@ def export(self) -> MetatensorAtomisticModel:

capabilities = ModelCapabilities(
outputs=self.outputs,
atomic_types=self.dataset_info.atomic_types,
atomic_types=self.atomic_types,
interaction_range=self.hypers["soap"]["cutoff"],
length_unit=self.dataset_info.length_unit,
supported_devices=self.__supported_devices__,
Expand All @@ -170,14 +171,14 @@ def export(self) -> MetatensorAtomisticModel:
def set_composition_weights(
self,
input_composition_weights: torch.Tensor,
species: List[int],
atomic_types: List[int],
) -> None:
"""Set the composition weights for a given output."""
input_composition_weights = input_composition_weights.to(
dtype=self.alchemical_model.composition_weights.dtype,
device=self.alchemical_model.composition_weights.device,
)
index = [self.dataset_info.atomic_types.index(s) for s in species]
index = [self.atomic_types.index(s) for s in atomic_types]
composition_weights = input_composition_weights[:, index]
self.alchemical_model.set_composition_weights(composition_weights)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_to(device, dtype):

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
atomic_types={1, 6, 7, 8},
targets={
"energy": TargetInfo(
quantity="energy",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_prediction_subset_elements():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
atomic_types={1, 6, 7, 8},
targets={
"energy": TargetInfo(
quantity="energy",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_rotational_invariance():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
atomic_types={1, 6, 7, 8},
targets={
"energy": TargetInfo(
quantity="energy",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_regression_init():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
atomic_types={1, 6, 7, 8},
targets={
"mtm::U0": TargetInfo(
quantity="energy",
Expand Down Expand Up @@ -84,25 +84,19 @@ def test_regression_train():
"read_from": DATASET_PATH,
"file_format": ".xyz",
"key": "U0",
"unit": "eV",
"forces": False,
"stress": False,
"virial": False,
}
}
targets = read_targets(OmegaConf.create(conf))
targets, target_info_dict = read_targets(OmegaConf.create(conf))
dataset = Dataset({"system": systems, "mtm::U0": targets["mtm::U0"]})

hypers = DEFAULT_HYPERS.copy()

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"mtm::U0": TargetInfo(
quantity="energy",
unit="eV",
),
},
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=target_info_dict
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_alchemical_model_inference():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=unique_numbers,
atomic_types=set(unique_numbers),
targets={
"energy": TargetInfo(
quantity="energy",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_torchscript():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
atomic_types={1, 6, 7, 8},
targets={
"energy": TargetInfo(
quantity="energy",
Expand All @@ -29,7 +29,7 @@ def test_torchscript_save_load():

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
atomic_types={1, 6, 7, 8},
targets={
"energy": TargetInfo(
quantity="energy",
Expand Down
25 changes: 19 additions & 6 deletions src/metatensor/models/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...utils.data import (
CombinedDataLoader,
Dataset,
TargetInfoDict,
check_datasets,
collate_fn,
get_all_targets,
Expand Down Expand Up @@ -96,23 +97,25 @@ def train(
f"Target {target_name} in the model's new capabilities is not "
"present in any of the training datasets."
)
composition_weights, species = calculate_composition_weights(
composition_weights, composition_types = calculate_composition_weights(
train_datasets_with_target, target_name
)
model.set_composition_weights(composition_weights.unsqueeze(0), species)
model.set_composition_weights(
composition_weights.unsqueeze(0), composition_types
)

# Remove the composition from the datasets:
train_datasets = [
remove_composition_from_dataset(
train_datasets[0],
model.dataset_info.atomic_types,
model.atomic_types,
model.alchemical_model.composition_weights.squeeze(0),
)
]
validation_datasets = [
remove_composition_from_dataset(
validation_datasets[0],
model.dataset_info.atomic_types,
model.atomic_types,
model.alchemical_model.composition_weights.squeeze(0),
)
]
Expand Down Expand Up @@ -212,7 +215,12 @@ def train(
predictions = evaluate_model(
model,
systems,
{key: model.dataset_info.targets[key] for key in targets.keys()},
TargetInfoDict(
**{
key: model.dataset_info.targets[key]
for key in targets.keys()
}
),
is_training=True,
)

Expand Down Expand Up @@ -242,7 +250,12 @@ def train(
predictions = evaluate_model(
model,
systems,
{key: model.dataset_info.targets[key] for key in targets.keys()},
TargetInfoDict(
**{
key: model.dataset_info.targets[key]
for key in targets.keys()
}
),
is_training=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_average_number_of_neighbors(

def remove_composition_from_dataset(
dataset: Union[Dataset, torch.utils.data.Subset],
all_species: List[int],
atomic_types: List[int],
composition_weights: torch.Tensor,
) -> List[Union[Dataset, torch.utils.data.Subset]]:
"""Remove the composition from the dataset.
Expand All @@ -84,8 +84,8 @@ def remove_composition_from_dataset(
system = dataset[i]["system"]
property = dataset[i][property_name]
numbers = system.types
composition = torch.bincount(numbers, minlength=max(all_species) + 1)
composition = composition[all_species].to(
composition = torch.bincount(numbers, minlength=max(atomic_types) + 1)
composition = composition[atomic_types].to(
device=composition_weights.device, dtype=composition_weights.dtype
)
property = metatensor.torch.subtract(
Expand Down
Loading

0 comments on commit bcfd69b

Please sign in to comment.