
Commit 91fb617
Merge branch 'main' into download-model
PicoCentauri committed Nov 8, 2024
2 parents dd8eba8 + eb34fed commit 91fb617
Showing 53 changed files with 1,076 additions and 725 deletions.
38 changes: 38 additions & 0 deletions docs/src/dev-docs/dataset-information.rst
@@ -0,0 +1,38 @@
Dataset Information
===================

When working with ``metatrain``, you will most likely need to interact with a few core
classes that store information about datasets. All of these classes belong to the
``metatrain.utils.data`` module, which is documented in the :ref:`data` section of the
developer documentation.

These classes are:

- :py:class:`metatrain.utils.data.DatasetInfo`: This class is responsible for storing
information about a dataset. It contains the length unit used in the dataset, the
atomic types present, and information about the dataset's targets as a
``Dict[str, TargetInfo]`` object. The keys of this dictionary are the names of the
targets in the dataset (e.g., ``energy``, ``mtt::dipole``, etc.).

- :py:class:`metatrain.utils.data.TargetInfo`: This class is responsible for storing
information about a target in a dataset. It contains the target's physical quantity,
the unit in which the target is expressed, and the ``layout`` of the target. The
``layout`` is a ``TensorMap`` object with zero samples, which is used to exemplify
the metadata of the target.
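
As a minimal sketch, here is how these classes fit together for a scalar ``energy``
target (assuming the ``metatensor.torch`` ``Labels``/``TensorBlock``/``TensorMap`` API
and the ``TargetInfo``/``DatasetInfo`` constructors used elsewhere in this commit)::

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap
    from metatrain.utils.data import DatasetInfo, TargetInfo

    # scalar layout: a single block with zero samples and no components
    energy_layout = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.empty(0, 1, dtype=torch.float64),
                samples=Labels(
                    names=["system"],
                    values=torch.empty((0, 1), dtype=torch.int32),
                ),
                components=[],
                properties=Labels.range("energy", 1),
            )
        ],
    )

    # describe the "energy" target and the dataset it belongs to
    energy_target = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
    dataset_info = DatasetInfo(
        length_unit="Angstrom",
        atomic_types=[1, 6, 7, 8],
        targets={"energy": energy_target},
    )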

At the moment, only three types of layouts are supported:

- scalar: This type of layout is used when the target is a scalar quantity. The
``layout`` ``TensorMap`` object corresponding to a scalar must have one
``TensorBlock`` and no ``components``.
- Cartesian tensor: This type of layout is used when the target is a Cartesian tensor.
The ``layout`` ``TensorMap`` object corresponding to a Cartesian tensor must have
one ``TensorBlock`` and as many ``components`` as the tensor's rank. These
components are named ``xyz`` for a tensor of rank 1 and ``xyz_1``, ``xyz_2``, and
so on for higher ranks.
- Spherical tensor: This type of layout is used when the target is a spherical tensor.
The ``layout`` ``TensorMap`` object corresponding to a spherical tensor can have
multiple blocks corresponding to different irreps (irreducible representations) of
the target. The ``keys`` of the ``TensorMap`` object must have the ``o3_lambda``
and ``o3_sigma`` names, and each ``TensorBlock`` must have a single component named
``o3_mu``.
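
For illustration, hedged sketches of a rank-1 Cartesian layout and a single-irrep
spherical layout following the conventions above (the sample and property names used
here are illustrative, not prescribed by ``metatrain``)::

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap

    # Cartesian rank-1 layout (e.g. a dipole): one block with an "xyz" component
    cartesian_layout = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.empty(0, 3, 1, dtype=torch.float64),
                samples=Labels(
                    names=["system"],
                    values=torch.empty((0, 1), dtype=torch.int32),
                ),
                components=[
                    Labels(
                        names=["xyz"],
                        values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
                    )
                ],
                properties=Labels.range("dipole", 1),
            )
        ],
    )

    # spherical layout: one block per irrep; the keys carry "o3_lambda" and
    # "o3_sigma", and each block has one "o3_mu" component of size 2 * lambda + 1
    spherical_layout = TensorMap(
        keys=Labels(
            names=["o3_lambda", "o3_sigma"],
            values=torch.tensor([[2, 1]], dtype=torch.int32),
        ),
        blocks=[
            TensorBlock(
                values=torch.empty(0, 5, 1, dtype=torch.float64),
                samples=Labels(
                    names=["system"],
                    values=torch.empty((0, 1), dtype=torch.int32),
                ),
                components=[
                    Labels(
                        names=["o3_mu"],
                        values=torch.arange(-2, 3, dtype=torch.int32).reshape(-1, 1),
                    )
                ],
                properties=Labels.range("values", 1),
            )
        ],
    )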
1 change: 1 addition & 0 deletions docs/src/dev-docs/index.rst
@@ -12,5 +12,6 @@ module.
getting-started
architecture-life-cycle
new-architecture
dataset-information
cli/index
utils/index
2 changes: 2 additions & 0 deletions docs/src/dev-docs/utils/data/index.rst
@@ -1,3 +1,5 @@
.. _data:

Data
====

74 changes: 65 additions & 9 deletions src/metatrain/cli/eval.py
@@ -15,7 +15,6 @@
from ..utils.data import (
Dataset,
TargetInfo,
TargetInfoDict,
collate_fn,
read_systems,
read_targets,
@@ -161,7 +160,7 @@ def _concatenate_tensormaps(
def _eval_targets(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
dataset: Union[Dataset, torch.utils.data.Subset],
options: TargetInfoDict,
options: Dict[str, TargetInfo],
return_predictions: bool,
check_consistency: bool = False,
) -> Optional[Dict[str, TensorMap]]:
@@ -337,17 +336,17 @@ def eval_model(
# (but we don't/can't calculate RMSEs)
# TODO: allow the user to specify which outputs to evaluate
eval_targets = {}
eval_info_dict = TargetInfoDict()
gradients = ["positions"]
if all(not torch.all(system.cell == 0) for system in eval_systems):
# only add strain if all structures have cells
gradients.append("strain")
eval_info_dict = {}
do_strain_grad = all(
not torch.all(system.cell == 0) for system in eval_systems
)
layout = _get_energy_layout(do_strain_grad) # TODO: layout from the user
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
per_atom=False, # TODO: allow the user to specify this
gradients=gradients,
# TODO: allow the user to specify whether per-atom or not
layout=layout,
)

eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})
@@ -370,3 +369,60 @@
capabilities=model.capabilities(),
predictions=predictions,
)


def _get_energy_layout(strain_gradient: bool) -> TensorMap:
block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 1, dtype=torch.float64),
samples=Labels(
names=["system"],
values=torch.empty((0, 1), dtype=torch.int32),
),
components=[],
properties=Labels.range("energy", 1),
)
position_gradient_block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 3, 1, dtype=torch.float64),
samples=Labels(
names=["sample", "atom"],
values=torch.empty((0, 2), dtype=torch.int32),
),
components=[
Labels(
names=["xyz"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.range("energy", 1),
)
block.add_gradient("positions", position_gradient_block)

if strain_gradient:
strain_gradient_block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 3, 3, 1, dtype=torch.float64),
samples=Labels(
names=["sample", "atom"],
values=torch.empty((0, 2), dtype=torch.int32),
),
components=[
Labels(
names=["xyz_1"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
Labels(
names=["xyz_2"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.range("energy", 1),
)
block.add_gradient("strain", strain_gradient_block)

energy_layout = TensorMap(
keys=Labels.single(),
blocks=[block],
)
return energy_layout
19 changes: 13 additions & 6 deletions src/metatrain/cli/train.py
@@ -21,7 +21,7 @@
)
from ..utils.data import (
DatasetInfo,
TargetInfoDict,
TargetInfo,
get_atomic_types,
get_dataset,
get_stats,
@@ -227,11 +227,18 @@ def train_model(
options["training_set"] = expand_dataset_config(options["training_set"])

train_datasets = []
target_infos = TargetInfoDict()
for train_options in options["training_set"]:
dataset, target_info_dict = get_dataset(train_options)
target_info_dict: Dict[str, TargetInfo] = {}
for train_options in options["training_set"]: # loop over training sets
dataset, target_info_dict_single = get_dataset(train_options)
train_datasets.append(dataset)
target_infos.update(target_info_dict)
intersecting_keys = target_info_dict.keys() & target_info_dict_single.keys()
for key in intersecting_keys:
if target_info_dict[key] != target_info_dict_single[key]:
raise ValueError(
f"Target information for key {key} differs between training sets. "
f"Got {target_info_dict[key]} and {target_info_dict_single[key]}."
)
target_info_dict.update(target_info_dict_single)

train_size = 1.0

@@ -320,7 +327,7 @@
dataset_info = DatasetInfo(
length_unit=options["training_set"][0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
targets=target_info_dict,
)

###########################
55 changes: 28 additions & 27 deletions src/metatrain/experimental/alchemical_model/default-hypers.yaml
@@ -1,29 +1,30 @@
-name: experimental.alchemical_model
+architecture:
+  name: experimental.alchemical_model

-model:
-  soap:
-    num_pseudo_species: 4
-    cutoff: 5.0
-    basis_cutoff_power_spectrum: 400
-    radial_basis_type: "physical"
-    basis_scale: 3.0
-    trainable_basis: true
-    normalize: true
-    contract_center_species: true
-  bpnn:
-    hidden_sizes: [32, 32]
-    output_size: 1
-  zbl: false
+  model:
+    soap:
+      num_pseudo_species: 4
+      cutoff: 5.0
+      basis_cutoff_power_spectrum: 400
+      radial_basis_type: "physical"
+      basis_scale: 3.0
+      trainable_basis: true
+      normalize: true
+      contract_center_species: true
+    bpnn:
+      hidden_sizes: [32, 32]
+      output_size: 1
+    zbl: false

-training:
-  batch_size: 8
-  num_epochs: 100
-  learning_rate: 0.001
-  early_stopping_patience: 200
-  scheduler_patience: 100
-  scheduler_factor: 0.8
-  log_interval: 5
-  checkpoint_interval: 25
-  per_structure_targets: []
-  loss_weights: {}
-  log_mae: False
+  training:
+    batch_size: 8
+    num_epochs: 100
+    learning_rate: 0.001
+    early_stopping_patience: 200
+    scheduler_patience: 100
+    scheduler_factor: 0.8
+    log_interval: 5
+    checkpoint_interval: 25
+    per_structure_targets: []
+    loss_weights: {}
+    log_mae: False
(diff of another changed file; file name not shown)
@@ -3,11 +3,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

@@ -22,7 +23,9 @@ def test_to(device, dtype):
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype)
exported = model.export()
(diff of another changed file; file name not shown)
@@ -2,11 +2,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

@@ -18,7 +19,9 @@ def test_prediction_subset_elements():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
(diff of another changed file; file name not shown)
@@ -5,11 +5,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, MODEL_HYPERS

@@ -20,7 +21,9 @@ def test_rotational_invariance():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

(diff of another changed file; file name not shown)
@@ -13,11 +13,11 @@
read_systems,
read_targets,
)
from metatrain.utils.data.dataset import TargetInfoDict
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS

@@ -31,8 +31,8 @@
def test_regression_init():
"""Perform a regression test on the model at initialization"""

targets = TargetInfoDict()
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
(diff of another changed file; file name not shown)
@@ -14,8 +14,9 @@
from metatrain.experimental.alchemical_model.utils import (
systems_to_torch_alchemical_batch,
)
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS, QM9_DATASET_PATH

@@ -71,7 +72,9 @@ def test_alchemical_model_inference():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=unique_numbers,
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info)
(diffs for the remaining changed files were not loaded)
