Skip to content

Commit

Permalink
Add sample TensorMap to TargetInfo
Browse files: browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 26, 2024
1 parent 2c0a748 commit 80e0d50
Show file tree
Hide file tree
Showing 44 changed files with 682 additions and 574 deletions.
16 changes: 9 additions & 7 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from metatensor.torch.atomistic import MetatensorAtomisticModel
from omegaconf import DictConfig, OmegaConf

from metatrain.utils.testing import energy_force_layout, energy_force_stress_layout

from ..utils.data import (
Dataset,
TargetInfo,
TargetInfoDict,
collate_fn,
read_systems,
read_targets,
Expand Down Expand Up @@ -159,7 +160,7 @@ def _concatenate_tensormaps(
def _eval_targets(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
dataset: Union[Dataset, torch.utils.data.Subset],
options: TargetInfoDict,
options: Dict[str, TargetInfo],
return_predictions: bool,
check_consistency: bool = False,
) -> Optional[Dict[str, TensorMap]]:
Expand Down Expand Up @@ -335,17 +336,18 @@ def eval_model(
# (but we don't/can't calculate RMSEs)
# TODO: allow the user to specify which outputs to evaluate
eval_targets = {}
eval_info_dict = TargetInfoDict()
gradients = ["positions"]
eval_info_dict = {}
if all(not torch.all(system.cell == 0) for system in eval_systems):
# only add strain if all structures have cells
gradients.append("strain")
layout = energy_force_stress_layout

[Annotation] Check warning on line 342 in src/metatrain/cli/eval.py (view check run for this annotation).
[Annotation] Codecov / codecov/patch, src/metatrain/cli/eval.py#L342: added line #L342 was not covered by tests.
else:
layout = energy_force_layout
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
per_atom=False, # TODO: allow the user to specify this
gradients=gradients,
# TODO: allow the user to specify whether per-atom or not
layout=layout,
)

eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})
Expand Down
18 changes: 12 additions & 6 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from ..utils.data import (
DatasetInfo,
TargetInfoDict,
TargetInfo,
get_atomic_types,
get_dataset,
get_stats,
Expand Down Expand Up @@ -227,11 +227,17 @@ def train_model(
options["training_set"] = expand_dataset_config(options["training_set"])

train_datasets = []
target_infos = TargetInfoDict()
for train_options in options["training_set"]:
dataset, target_info_dict = get_dataset(train_options)
target_info_dict: Dict[str, TargetInfo] = {}
for train_options in options["training_set"]: # loop over training sets
dataset, target_info_dict_single = get_dataset(train_options)
train_datasets.append(dataset)
target_infos.update(target_info_dict)
intersecting_keys = target_info_dict.keys() & target_info_dict_single.keys()
for key in intersecting_keys:
if target_info_dict[key] != target_info_dict_single[key]:
raise ValueError(

[Annotation] Check warning on line 237 in src/metatrain/cli/train.py (view check run for this annotation).
[Annotation] Codecov / codecov/patch, src/metatrain/cli/train.py#L237: added line #L237 was not covered by tests.
f"Target information for key {key} differs between training sets."
)
target_info_dict.update(target_info_dict_single)

train_size = 1.0

Expand Down Expand Up @@ -320,7 +326,7 @@ def train_model(
dataset_info = DatasetInfo(
length_unit=options["training_set"][0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
targets=target_info_dict,
)

###########################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -22,7 +23,9 @@ def test_to(device, dtype):
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype)
exported = model.export()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -18,7 +19,9 @@ def test_prediction_subset_elements():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, MODEL_HYPERS

Expand All @@ -20,7 +21,9 @@ def test_rotational_invariance():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
read_systems,
read_targets,
)
from metatrain.utils.data.dataset import TargetInfoDict
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS

Expand All @@ -31,8 +31,8 @@
def test_regression_init():
"""Perform a regression test on the model at initialization"""

targets = TargetInfoDict()
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from metatrain.experimental.alchemical_model.utils import (
systems_to_torch_alchemical_batch,
)
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS, QM9_DATASET_PATH

Expand Down Expand Up @@ -70,7 +71,9 @@ def test_alchemical_model_inference():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=unique_numbers,
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -12,7 +13,9 @@ def test_torchscript():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand All @@ -25,7 +28,9 @@ def test_torchscript_save_load():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)
torch.jit.save(
Expand Down
15 changes: 2 additions & 13 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ...utils.data import (
CombinedDataLoader,
Dataset,
TargetInfoDict,
check_datasets,
collate_fn,
get_all_targets,
Expand Down Expand Up @@ -247,12 +246,7 @@ def train(
predictions = evaluate_model(
model,
systems,
TargetInfoDict(
**{
key: model.dataset_info.targets[key]
for key in targets.keys()
}
),
{key: model.dataset_info.targets[key] for key in targets.keys()},
is_training=True,
)

Expand Down Expand Up @@ -295,12 +289,7 @@ def train(
predictions = evaluate_model(
model,
systems,
TargetInfoDict(
**{
key: model.dataset_info.targets[key]
for key in targets.keys()
}
),
{key: model.dataset_info.targets[key] for key in targets.keys()},
is_training=False,
)

Expand Down
10 changes: 6 additions & 4 deletions src/metatrain/experimental/gap/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
Dataset,
DatasetInfo,
TargetInfo,
TargetInfoDict,
read_systems,
read_targets,
)
from metatrain.utils.testing import energy_force_layout

from . import DATASET_ETHANOL_PATH, DEFAULT_HYPERS

Expand Down Expand Up @@ -61,9 +61,11 @@ def test_ethanol_regression_train_and_invariance():
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 30

target_info_dict = TargetInfoDict(
energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"])
)
target_info_dict = {
"energy": TargetInfo(
quantity="energy", unit="kcal/mol", layout=energy_force_layout
)
}

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
Expand Down
21 changes: 13 additions & 8 deletions src/metatrain/experimental/gap/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from omegaconf import OmegaConf

from metatrain.experimental.gap import GAP, Trainer
from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo
from metatrain.utils.data.readers import read_systems, read_targets
from metatrain.utils.testing import energy_force_layout, energy_layout

from . import DATASET_ETHANOL_PATH, DATASET_PATH, DEFAULT_HYPERS

Expand All @@ -25,8 +26,8 @@

def test_regression_init():
"""Perform a regression test on the model at initialization"""
targets = TargetInfoDict()
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
Expand Down Expand Up @@ -57,8 +58,10 @@ def test_regression_train_and_invariance():
targets, _ = read_targets(OmegaConf.create(conf))
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

target_info_dict = TargetInfoDict()
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
target_info_dict = {}
target_info_dict["mtt::U0"] = TargetInfo(
quantity="energy", unit="eV", layout=energy_layout
)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
Expand Down Expand Up @@ -138,9 +141,11 @@ def test_ethanol_regression_train_and_invariance():
hypers = copy.deepcopy(DEFAULT_HYPERS)
hypers["model"]["krr"]["num_sparse_points"] = 900

target_info_dict = TargetInfoDict(
energy=TargetInfo(quantity="energy", unit="kcal/mol", gradients=["positions"])
)
target_info_dict = {
"energy": TargetInfo(
quantity="energy", unit="kcal/mol", layout=energy_force_layout
)
}

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
Expand Down
15 changes: 8 additions & 7 deletions src/metatrain/experimental/gap/tests/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from omegaconf import OmegaConf

from metatrain.experimental.gap import GAP, Trainer
from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo
from metatrain.utils.data.readers import read_systems, read_targets
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, DEFAULT_HYPERS

Expand All @@ -13,8 +14,10 @@

def test_torchscript():
"""Tests that the model can be jitted."""
target_info_dict = TargetInfoDict()
target_info_dict["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
target_info_dict = {}
target_info_dict["mtt::U0"] = TargetInfo(
quantity="energy", unit="eV", layout=energy_layout
)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict
Expand All @@ -34,8 +37,6 @@ def test_torchscript():
targets, _ = read_targets(OmegaConf.create(conf))
systems = read_systems(DATASET_PATH)

# for system in systems:
# system.types = torch.ones(len(system.types), dtype=torch.int32)
dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
Expand Down Expand Up @@ -64,8 +65,8 @@ def test_torchscript():

def test_torchscript_save():
"""Tests that the model can be jitted and saved."""
targets = TargetInfoDict()
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
Expand Down
Loading

0 comments on commit 80e0d50

Please sign in to comment.