Skip to content

Commit

Permalink
Add sample TensorMap to TargetInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 26, 2024
1 parent 2c0a748 commit b8cd9ba
Show file tree
Hide file tree
Showing 44 changed files with 738 additions and 576 deletions.
71 changes: 62 additions & 9 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ..utils.data import (
Dataset,
TargetInfo,
TargetInfoDict,
collate_fn,
read_systems,
read_targets,
Expand Down Expand Up @@ -159,7 +158,7 @@ def _concatenate_tensormaps(
def _eval_targets(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
dataset: Union[Dataset, torch.utils.data.Subset],
options: TargetInfoDict,
options: Dict[str, TargetInfo],
return_predictions: bool,
check_consistency: bool = False,
) -> Optional[Dict[str, TensorMap]]:
Expand Down Expand Up @@ -335,17 +334,17 @@ def eval_model(
# (but we don't/can't calculate RMSEs)
# TODO: allow the user to specify which outputs to evaluate
eval_targets = {}
eval_info_dict = TargetInfoDict()
gradients = ["positions"]
if all(not torch.all(system.cell == 0) for system in eval_systems):
# only add strain if all structures have cells
gradients.append("strain")
eval_info_dict = {}
do_strain_grad = all(
not torch.all(system.cell == 0) for system in eval_systems
)
layout = _get_energy_layout(do_strain_grad) # TODO: layout from the user
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
per_atom=False, # TODO: allow the user to specify this
gradients=gradients,
# TODO: allow the user to specify whether per-atom or not
layout=layout,
)

eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})
Expand All @@ -368,3 +367,57 @@ def eval_model(
capabilities=model.capabilities(),
predictions=predictions,
)


def _get_energy_layout(strain_gradient: bool) -> TensorMap:
    """Build an empty ``TensorMap`` describing the layout of an energy target.

    The returned map has a single key (``Labels.single()``) and a single block
    with zero samples (sample name ``"system"``), no components and one
    ``"energy"`` property. A ``"positions"`` gradient is always attached; a
    ``"strain"`` gradient is attached only on request. All value tensors are
    empty along the sample dimension.

    :param strain_gradient: whether to also attach a ``"strain"`` gradient
        block (with two Cartesian component axes, ``"xyz_1"`` and ``"xyz_2"``)
    :return: the empty layout ``TensorMap``
    """

    def _cartesian_axis(name: str) -> Labels:
        # one component axis with the three integer entries 0, 1, 2
        return Labels(
            names=[name],
            values=torch.arange(3, dtype=torch.int32).reshape(3, 1),
        )

    # gradient parameter -> names of its Cartesian component axes
    gradient_axes = {"positions": ["xyz"]}
    if strain_gradient:
        # NOTE(review): the strain gradient reuses the ("sample", "atom")
        # sample names of the positions gradient -- confirm this is intended.
        gradient_axes["strain"] = ["xyz_1", "xyz_2"]

    layout_block = TensorBlock(
        values=torch.empty(0, 1),
        samples=Labels(
            names=["system"],
            values=torch.empty((0, 1), dtype=torch.int32),
        ),
        components=[],
        properties=Labels.range("energy", 1),
    )

    for parameter, axis_names in gradient_axes.items():
        # (0 samples, 3 per Cartesian axis, 1 property)
        gradient_shape = (0,) + (3,) * len(axis_names) + (1,)
        layout_block.add_gradient(
            parameter,
            TensorBlock(
                values=torch.empty(gradient_shape),
                samples=Labels(
                    names=["sample", "atom"],
                    values=torch.empty((0, 2), dtype=torch.int32),
                ),
                components=[_cartesian_axis(axis) for axis in axis_names],
                properties=Labels.range("energy", 1),
            ),
        )

    return TensorMap(keys=Labels.single(), blocks=[layout_block])
18 changes: 12 additions & 6 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from ..utils.data import (
DatasetInfo,
TargetInfoDict,
TargetInfo,
get_atomic_types,
get_dataset,
get_stats,
Expand Down Expand Up @@ -227,11 +227,17 @@ def train_model(
options["training_set"] = expand_dataset_config(options["training_set"])

train_datasets = []
target_infos = TargetInfoDict()
for train_options in options["training_set"]:
dataset, target_info_dict = get_dataset(train_options)
target_info_dict: Dict[str, TargetInfo] = {}
for train_options in options["training_set"]: # loop over training sets
dataset, target_info_dict_single = get_dataset(train_options)
train_datasets.append(dataset)
target_infos.update(target_info_dict)
intersecting_keys = target_info_dict.keys() & target_info_dict_single.keys()
for key in intersecting_keys:
if target_info_dict[key] != target_info_dict_single[key]:
raise ValueError(

Check warning on line 237 in src/metatrain/cli/train.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/cli/train.py#L237

Added line #L237 was not covered by tests
f"Target information for key {key} differs between training sets."
)
target_info_dict.update(target_info_dict_single)

train_size = 1.0

Expand Down Expand Up @@ -320,7 +326,7 @@ def train_model(
dataset_info = DatasetInfo(
length_unit=options["training_set"][0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
targets=target_info_dict,
)

###########################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -22,7 +23,9 @@ def test_to(device, dtype):
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype)
exported = model.export()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -18,7 +19,9 @@ def test_prediction_subset_elements():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, MODEL_HYPERS

Expand All @@ -20,7 +21,9 @@ def test_rotational_invariance():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
read_systems,
read_targets,
)
from metatrain.utils.data.dataset import TargetInfoDict
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS

Expand All @@ -31,8 +31,8 @@
def test_regression_init():
"""Perform a regression test on the model at initialization"""

targets = TargetInfoDict()
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from metatrain.experimental.alchemical_model.utils import (
systems_to_torch_alchemical_batch,
)
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS, QM9_DATASET_PATH

Expand Down Expand Up @@ -70,7 +71,9 @@ def test_alchemical_model_inference():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=unique_numbers,
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -12,7 +13,9 @@ def test_torchscript():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand All @@ -25,7 +28,9 @@ def test_torchscript_save_load():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)
torch.jit.save(
Expand Down
15 changes: 2 additions & 13 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ...utils.data import (
CombinedDataLoader,
Dataset,
TargetInfoDict,
check_datasets,
collate_fn,
get_all_targets,
Expand Down Expand Up @@ -247,12 +246,7 @@ def train(
predictions = evaluate_model(
model,
systems,
TargetInfoDict(
**{
key: model.dataset_info.targets[key]
for key in targets.keys()
}
),
{key: model.dataset_info.targets[key] for key in targets.keys()},
is_training=True,
)

Expand Down Expand Up @@ -295,12 +289,7 @@ def train(
predictions = evaluate_model(
model,
systems,
TargetInfoDict(
**{
key: model.dataset_info.targets[key]
for key in targets.keys()
}
),
{key: model.dataset_info.targets[key] for key in targets.keys()},
is_training=False,
)

Expand Down
10 changes: 6 additions & 4 deletions src/metatrain/experimental/gap/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
Dataset,
DatasetInfo,
TargetInfo,
TargetInfoDict,
read_systems,
read_targets,
)
from metatrain.utils.testing import energy_force_layout

from . import DATASET_ETHANOL_PATH, DEFAULT_HYPERS

Expand Down Expand Up @@ -61,9 +61,11 @@ def test_ethanol_regression_train_and_invariance():
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 30

target_info_dict = TargetInfoDict(
energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"])
)
target_info_dict = {
"energy": TargetInfo(
quantity="energy", unit="kcal/mol", layout=energy_force_layout
)
}

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
Expand Down
Loading

0 comments on commit b8cd9ba

Please sign in to comment.