Skip to content

Commit

Permalink
Integrate with metatensor.torch.atomistic (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Jan 11, 2024
1 parent 3882faa commit 4d7f847
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 32 deletions.
4 changes: 2 additions & 2 deletions docs/src/dev-docs/utils/readers/structure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Structure Readers
#################

Parsers for obtaining information from structures. All readers return a :py:class:`list`
of :py:class:`rascaline.torch.system.System`. The mapping which reader is used for which
file type is stored in
of :py:class:`metatensor.torch.atomistic.System`. The mapping which reader is used for
which file type is stored in

.. autodata:: metatensor.models.utils.data.readers.structures.STRUCTURE_READERS

Expand Down
29 changes: 22 additions & 7 deletions src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import ModelCapabilities, System
from omegaconf import OmegaConf

from .. import ARCHITECTURE_CONFIG_PATH
Expand Down Expand Up @@ -93,11 +94,25 @@ def forward(self, features: TensorMap) -> TensorMap:

class Model(torch.nn.Module):
def __init__(
self, all_species: List[int], hypers: Dict = DEFAULT_MODEL_HYPERS
self, capabilities: ModelCapabilities, hypers: Dict = DEFAULT_MODEL_HYPERS
) -> None:
super().__init__()
self.name = ARCHITECTURE_NAME
self.all_species = all_species

# Check capabilities
if len(capabilities.outputs) > 1:
raise ValueError(
"SOAP-BPNN only supports a single output, "
"but multiple outputs were provided"
)
if next(iter(capabilities.outputs.values())).quantity != "energy":
raise ValueError(
"SOAP-BPNN only supports energy-like outputs, "
f"but {next(iter(capabilities.outputs.values())).quantity} was provided"
)

self.capabilities = capabilities
self.all_species = capabilities.species
self.hypers = hypers

# creates a composition weight tensor that can be directly indexed by species,
Expand All @@ -110,22 +125,22 @@ def __init__(
self.soap_calculator = rascaline.torch.SoapPowerSpectrum(**hypers["soap"])
hypers_bpnn = hypers["bpnn"]
hypers_bpnn["input_size"] = (
len(all_species) ** 2
len(self.all_species) ** 2
* hypers["soap"]["max_radial"] ** 2
* (hypers["soap"]["max_angular"] + 1)
)
hypers_bpnn["output_size"] = 1
self.bpnn = MLPMap(all_species, hypers_bpnn)
self.bpnn = MLPMap(self.all_species, hypers_bpnn)
self.neighbor_species_1_labels = Labels(
names=["species_neighbor_1"],
values=torch.tensor(all_species).reshape(-1, 1),
values=torch.tensor(self.all_species).reshape(-1, 1),
)
self.neighbor_species_2_labels = Labels(
names=["species_neighbor_2"],
values=torch.tensor(all_species).reshape(-1, 1),
values=torch.tensor(self.all_species).reshape(-1, 1),
)

def forward(self, systems: List[rascaline.torch.System]) -> Dict[str, TensorMap]:
def forward(self, systems: List[System]) -> Dict[str, TensorMap]:
soap_features = self.soap_calculator(systems)

device = soap_features.block(0).values.device
Expand Down
15 changes: 13 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ase
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model

Expand All @@ -9,8 +10,18 @@ def test_prediction_subset():
"""Tests that the model can predict on a subset
of the elements it was trained on."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)
capabilities = ModelCapabilities(
length_unit="Angstrom",
species=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
quantity="energy",
unit="eV",
)
},
)

soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)

structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn([rascaline.torch.systems_to_torch(structure)])
14 changes: 12 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ase.io
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model

Expand All @@ -12,8 +13,17 @@
def test_rotational_invariance():
"""Tests that the model is rotationally invariant."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)
capabilities = ModelCapabilities(
length_unit="Angstrom",
species=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
quantity="energy",
unit="eV",
)
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)

structure = ase.io.read(DATASET_PATH)
original_structure = copy.deepcopy(structure)
Expand Down
14 changes: 12 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ase.io
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model, train
from metatensor.models.utils.data import Dataset
Expand All @@ -15,8 +16,17 @@
def test_regression_init():
"""Perform a regression test on the model at initialization"""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)
capabilities = ModelCapabilities(
length_unit="Angstrom",
species=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
quantity="energy",
unit="eV",
)
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)

# Predict on the first fivestructures
structures = ase.io.read(DATASET_PATH, ":5")
Expand Down
14 changes: 12 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model


def test_torchscript():
"""Tests that the model can be jitted."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEFAULT_HYPERS["model"]).to(torch.float64)
capabilities = ModelCapabilities(
length_unit="Angstrom",
species=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
quantity="energy",
unit="eV",
)
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)
torch.jit.script(soap_bpnn)
17 changes: 15 additions & 2 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
Expand All @@ -17,10 +18,21 @@ def loss_function(predicted, target):


def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
# Calculate and set the composition weights:
# Set the model's capabilities:
model_capabilities = ModelCapabilities(
length_unit="Angstrom",
species=train_dataset.all_species,
outputs={
"U0": ModelOutput(
quantity="energy",
unit="eV",
)
},
)

# Create the model:
model = Model(
all_species=train_dataset.all_species,
capabilities=model_capabilities,
hypers=hypers["model"],
)

Expand All @@ -32,6 +44,7 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
else:
target = list(train_dataset.targets.keys())[0]

# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, target)
model.set_composition_weights(composition_weights)

Expand Down
12 changes: 5 additions & 7 deletions src/metatensor/models/utils/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from typing import Dict, List

import metatensor.torch
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import System


class Dataset(torch.utils.data.Dataset):
def __init__(
self, structures: List[rascaline.torch.System], targets: Dict[str, TensorMap]
):
def __init__(self, structures: List[System], targets: Dict[str, TensorMap]):
"""
Creates a dataset from a list of `rascaline.torch.System` objects
and a dictionary of targets where the keys are strings and the
values are `TensorMap` objects.
Creates a dataset from a list of `metatensor.torch.atomistic.System`
objects and a dictionary of targets where the keys are strings and
the values are `TensorMap` objects.
"""

for tensor_map in targets.values():
Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .structures import STRUCTURE_READERS
from .targets import TARGET_READERS

from rascaline.torch.system import System
from metatensor.torch.atomistic import System


def read_structures(filename: str, fileformat: Optional[str] = None) -> List[System]:
Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pathlib import Path
from metatensor.torch import TensorMap
from rascaline.torch.system import System
from metatensor.torch.atomistic import System

from .xyz import write_xyz

Expand Down
4 changes: 2 additions & 2 deletions src/metatensor/models/utils/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def save_model(
"architecture_name": model.name,
"model_state_dict": model.state_dict(),
"model_hypers": model.hypers,
"all_species": model.all_species,
"model_capabilities": model.capabilities,
},
path,
)
Expand Down Expand Up @@ -47,7 +47,7 @@ def load_model(path: str) -> torch.nn.Module:

# Create the model
model = architecture.Model(
all_species=model_dict["all_species"], hypers=model_dict["model_hypers"]
capabilities=model_dict["model_capabilities"], hypers=model_dict["model_hypers"]
)

# Load the model weights
Expand Down
Binary file modified tests/resources/bpnn-model.pt
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/utils/data/test_target_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pytest
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from rascaline.torch.system import System, systems_to_torch
from metatensor.torch.atomistic import System
from rascaline.torch import systems_to_torch

from metatensor.models.utils.data.writers import write_predictions, write_xyz

Expand Down
14 changes: 13 additions & 1 deletion tests/utils/test_model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import metatensor.torch
import rascaline.torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models import soap_bpnn
from metatensor.models.utils.data import read_structures
Expand All @@ -15,7 +16,18 @@ def test_save_load_model(monkeypatch, tmp_path):
"""Test that saving and loading a model works and preserves its internal state."""
monkeypatch.chdir(tmp_path)

model = soap_bpnn.Model(all_species=[1, 6, 7, 8])
capabilities = ModelCapabilities(
length_unit="Angstrom",
species=[1, 6, 7, 8],
outputs={
"energy": ModelOutput(
quantity="energy",
unit="eV",
)
},
)

model = soap_bpnn.Model(capabilities)
structures = read_structures(RESOURCES_PATH / "qm9_reduced_100.xyz")

output_before_save = model(rascaline.torch.systems_to_torch(structures))
Expand Down

0 comments on commit 4d7f847

Please sign in to comment.