Add example TensorMap layout to TargetInfo #370

Merged: 14 commits, Nov 6, 2024
38 changes: 38 additions & 0 deletions docs/src/dev-docs/dataset-information.rst
@@ -0,0 +1,38 @@
Dataset Information
===================

When working with ``metatrain``, you will most likely need to interact with a few core
classes that store information about datasets. All of these classes belong to the
``metatrain.utils.data`` module, which is documented in the :ref:`data` section of the
developer documentation.

These classes are:

- :py:class:`metatrain.utils.data.DatasetInfo`: This class is responsible for storing
information about a dataset. It contains the length unit used in the dataset, the
atomic types present, and information about the dataset's targets as a
``Dict[str, TargetInfo]`` object. The keys of this dictionary are the names of the
targets in the dataset (e.g., ``energy``, ``mtt::dipole``, etc.).

- :py:class:`metatrain.utils.data.TargetInfo`: This class is responsible for storing
information about a target in a dataset. It contains the target's physical quantity,
the unit in which the target is expressed, and the ``layout`` of the target. The
``layout`` is a ``TensorMap`` object with zero samples, which is used to exemplify
the metadata of the target (see the sketch just below).
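
For illustration, here is a minimal sketch of how the two classes fit together for a
simple scalar ``energy`` target (the scalar layout is described in more detail below;
the unit, atomic types, and property names are only examples)::

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap

    from metatrain.utils.data import DatasetInfo, TargetInfo

    # the layout is a TensorMap with zero samples describing the target's metadata
    energy_layout = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.empty(0, 1, dtype=torch.float64),
                samples=Labels(names=["system"], values=torch.empty((0, 1), dtype=torch.int32)),
                components=[],  # scalar target: no components
                properties=Labels.range("energy", 1),
            )
        ],
    )

    dataset_info = DatasetInfo(
        length_unit="Angstrom",
        atomic_types=[1, 6, 7, 8],
        targets={"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)},
    )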

At the moment, only three types of layouts are supported:

- scalar: This type of layout is used when the target is a scalar quantity. The
``layout`` ``TensorMap`` object corresponding to a scalar must have one
``TensorBlock`` and no ``components``.
- Cartesian tensor: This type of layout is used when the target is a Cartesian tensor.
The ``layout`` ``TensorMap`` object corresponding to a Cartesian tensor must have
one ``TensorBlock`` and as many ``components`` as the tensor's rank. These
components are named ``xyz`` for a tensor of rank 1 and ``xyz_1``, ``xyz_2``, and
so on for higher ranks.
- Spherical tensor: This type of layout is used when the target is a spherical tensor.
The ``layout`` ``TensorMap`` object corresponding to a spherical tensor can have
multiple blocks corresponding to different irreps (irreducible representations) of
the target. The ``keys`` of the ``TensorMap`` object must have the ``o3_lambda``
and ``o3_sigma`` names, and each ``TensorBlock`` must have a single component named
``o3_mu``.
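
As a further illustration of the Cartesian and spherical cases above, a minimal sketch
(the rank-1 Cartesian block could describe, e.g., a dipole; the ``system`` sample name
and the property labels are only examples)::

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap

    # Cartesian rank-1 layout: one block with a single "xyz" component of size 3
    cartesian_rank1_layout = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.empty(0, 3, 1, dtype=torch.float64),
                samples=Labels(names=["system"], values=torch.empty((0, 1), dtype=torch.int32)),
                components=[
                    Labels(names=["xyz"], values=torch.arange(3, dtype=torch.int32).reshape(-1, 1))
                ],
                properties=Labels.range("values", 1),
            )
        ],
    )

    # spherical layout: one block per irrep, keyed by (o3_lambda, o3_sigma), each with
    # a single "o3_mu" component of size 2 * o3_lambda + 1 (here, a lambda = 1 irrep)
    spherical_l1_layout = TensorMap(
        keys=Labels(names=["o3_lambda", "o3_sigma"], values=torch.tensor([[1, 1]], dtype=torch.int32)),
        blocks=[
            TensorBlock(
                values=torch.empty(0, 3, 1, dtype=torch.float64),
                samples=Labels(names=["system"], values=torch.empty((0, 1), dtype=torch.int32)),
                components=[
                    Labels(names=["o3_mu"], values=torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1))
                ],
                properties=Labels.range("values", 1),
            )
        ],
    )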
1 change: 1 addition & 0 deletions docs/src/dev-docs/index.rst
@@ -12,5 +12,6 @@ module.
getting-started
architecture-life-cycle
new-architecture
dataset-information
cli/index
utils/index
2 changes: 2 additions & 0 deletions docs/src/dev-docs/utils/data/index.rst
@@ -1,6 +1,8 @@
Data
====

.. _data:

API for handling data in ``metatrain``.

.. toctree::
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -10,9 +10,9 @@ authors = [{name = "metatrain developers"}]

dependencies = [
"ase < 3.23.0",
"metatensor-learn==0.2.3",
"metatensor-operations==0.2.3",
"metatensor-torch==0.5.5",
"metatensor-learn==0.3.0",
"metatensor-operations==0.3.0",
"metatensor-torch==0.6.0",
"jsonschema",
"omegaconf",
"python-hostlist",
@@ -59,7 +59,7 @@ build-backend = "setuptools.build_meta"

[project.optional-dependencies]
soap-bpnn = [
"rascaline-torch @ git+https://github.com/luthaf/rascaline@d181b28#subdirectory=python/rascaline-torch",
"rascaline-torch @ git+https://github.com/luthaf/rascaline@5326b6e#subdirectory=python/rascaline-torch",
]
alchemical-model = [
"torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@51ff519",
@@ -68,7 +68,7 @@ pet = [
"pet @ git+https://github.com/lab-cosmo/pet@7eddb2e",
]
gap = [
"rascaline-torch @ git+https://github.com/luthaf/rascaline@d181b28#subdirectory=python/rascaline-torch",
"rascaline-torch @ git+https://github.com/luthaf/rascaline@5326b6e#subdirectory=python/rascaline-torch",
Member:
I don't understand why these changes appear here, when they should already be on master. Could you squash this PR & rebase?

Collaborator (Author):
I merged (saved me some time compared to rebasing). This should be gone

"skmatter",
"metatensor-learn",
"scipy",
74 changes: 65 additions & 9 deletions src/metatrain/cli/eval.py
@@ -15,7 +15,6 @@
from ..utils.data import (
Dataset,
TargetInfo,
TargetInfoDict,
collate_fn,
read_systems,
read_targets,
@@ -159,7 +158,7 @@
def _eval_targets(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
dataset: Union[Dataset, torch.utils.data.Subset],
options: TargetInfoDict,
options: Dict[str, TargetInfo],
return_predictions: bool,
check_consistency: bool = False,
) -> Optional[Dict[str, TensorMap]]:
@@ -335,17 +334,17 @@
# (but we don't/can't calculate RMSEs)
# TODO: allow the user to specify which outputs to evaluate
eval_targets = {}
eval_info_dict = TargetInfoDict()
gradients = ["positions"]
if all(not torch.all(system.cell == 0) for system in eval_systems):
# only add strain if all structures have cells
gradients.append("strain")
eval_info_dict = {}
do_strain_grad = all(
not torch.all(system.cell == 0) for system in eval_systems
)
layout = _get_energy_layout(do_strain_grad) # TODO: layout from the user
for key in model.capabilities().outputs.keys():
eval_info_dict[key] = TargetInfo(
quantity=model.capabilities().outputs[key].quantity,
unit=model.capabilities().outputs[key].unit,
per_atom=False, # TODO: allow the user to specify this
gradients=gradients,
# TODO: allow the user to specify whether per-atom or not
layout=layout,
)

eval_dataset = Dataset.from_dict({"system": eval_systems, **eval_targets})
@@ -368,3 +367,60 @@
capabilities=model.capabilities(),
predictions=predictions,
)


# builds the zero-sample "energy" layout: a single scalar block with a "positions"
# gradient and, when strain_gradient is True, an additional "strain" gradient
def _get_energy_layout(strain_gradient: bool) -> TensorMap:
block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 1, dtype=torch.float64),
samples=Labels(
names=["system"],
values=torch.empty((0, 1), dtype=torch.int32),
),
components=[],
properties=Labels.range("energy", 1),
)
position_gradient_block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 3, 1, dtype=torch.float64),
samples=Labels(
names=["sample", "atom"],
values=torch.empty((0, 2), dtype=torch.int32),
),
components=[
Labels(
names=["xyz"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.range("energy", 1),
)
block.add_gradient("positions", position_gradient_block)

if strain_gradient:
strain_gradient_block = TensorBlock(

# float64: otherwise metatensor can't serialize
values=torch.empty(0, 3, 3, 1, dtype=torch.float64),
samples=Labels(
names=["sample", "atom"],
values=torch.empty((0, 2), dtype=torch.int32),
),
components=[
Labels(
names=["xyz_1"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
Labels(
names=["xyz_2"],
values=torch.arange(3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.range("energy", 1),
)
block.add_gradient("strain", strain_gradient_block)


energy_layout = TensorMap(
keys=Labels.single(),
blocks=[block],
)
return energy_layout
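
For reference (not part of the diff), a quick sketch of the metadata the helper above
produces when strain gradients are requested:

layout = _get_energy_layout(True)

block = layout.block(0)          # single block under the Labels.single() key
print(block.samples.names)       # ["system"]: zero samples, metadata only
print(block.components)          # []: the energy is a scalar target
print(block.gradients_list())    # ["positions", "strain"]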
18 changes: 12 additions & 6 deletions src/metatrain/cli/train.py
@@ -21,7 +21,7 @@
)
from ..utils.data import (
DatasetInfo,
TargetInfoDict,
TargetInfo,
get_atomic_types,
get_dataset,
get_stats,
@@ -227,11 +227,17 @@
options["training_set"] = expand_dataset_config(options["training_set"])

train_datasets = []
target_infos = TargetInfoDict()
for train_options in options["training_set"]:
dataset, target_info_dict = get_dataset(train_options)
target_info_dict: Dict[str, TargetInfo] = {}
for train_options in options["training_set"]: # loop over training sets
dataset, target_info_dict_single = get_dataset(train_options)
train_datasets.append(dataset)
target_infos.update(target_info_dict)
intersecting_keys = target_info_dict.keys() & target_info_dict_single.keys()
for key in intersecting_keys:
if target_info_dict[key] != target_info_dict_single[key]:
raise ValueError(
f"Target information for key {key} differs between training sets."

Member:
it would be nicer to add what the difference is here, so it is clear to the users

)
target_info_dict.update(target_info_dict_single)
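
One possible way to address the review comment above, purely as a sketch (it assumes
that ``TargetInfo`` has an informative repr, which is not checked here):

for key in intersecting_keys:
    if target_info_dict[key] != target_info_dict_single[key]:
        raise ValueError(
            f"Target information for key {key} differs between training sets: "
            f"{target_info_dict[key]} vs {target_info_dict_single[key]}."
        )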

train_size = 1.0

Expand Down Expand Up @@ -320,7 +326,7 @@
dataset_info = DatasetInfo(
length_unit=options["training_set"][0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
targets=target_info_dict,
)

###########################
1 change: 1 addition & 0 deletions src/metatrain/experimental/alchemical_model/model.py
@@ -78,6 +78,7 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]:
NeighborListOptions(
cutoff=self.cutoff,
full_list=True,
strict=True,
)
]

@@ -3,11 +3,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

@@ -22,7 +23,9 @@ def test_to(device, dtype):
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype)
exported = model.export()
@@ -33,6 +36,7 @@
types=torch.tensor([6, 6]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
pbc=torch.tensor([False, False, False]),
)
requested_neighbor_lists = get_requested_neighbor_lists(exported)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
@@ -2,11 +2,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

@@ -18,7 +19,9 @@ def test_prediction_subset_elements():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
@@ -27,6 +30,7 @@
types=torch.tensor([6, 6]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
pbc=torch.tensor([False, False, False]),
)
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
@@ -5,11 +5,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, MODEL_HYPERS

@@ -20,7 +21,9 @@ def test_rotational_invariance():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

@@ -13,11 +13,11 @@
read_systems,
read_targets,
)
from metatrain.utils.data.dataset import TargetInfoDict
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS

Expand All @@ -31,8 +31,8 @@
def test_regression_init():
"""Perform a regression test on the model at initialization"""

targets = TargetInfoDict()
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV")
targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
@@ -14,8 +14,9 @@
from metatrain.experimental.alchemical_model.utils import (
systems_to_torch_alchemical_batch,
)
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS, QM9_DATASET_PATH

@@ -29,6 +30,7 @@
nl_options = NeighborListOptions(
cutoff=5.0,
full_list=True,
strict=True,
)
systems = [get_system_with_neighbor_lists(system, [nl_options]) for system in systems]

@@ -70,7 +72,9 @@ def test_alchemical_model_inference():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=unique_numbers,
targets=TargetInfoDict(energy=TargetInfo(quantity="energy", unit="eV")),
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
)

alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info)