Skip to content

Commit

Permalink
Some working tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 30, 2023
1 parent 9df6409 commit 47d0474
Show file tree
Hide file tree
Showing 10 changed files with 1,282 additions and 31 deletions.
10 changes: 0 additions & 10 deletions src/metatensor_models/soap-bpnn/tests/test_invariance.py

This file was deleted.

7 changes: 0 additions & 7 deletions src/metatensor_models/soap-bpnn/tests/test_regression.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/metatensor_models/soap_bpnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .model import SoapBPNN
from .train import train
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ soap:
max_radial: 8
max_angular: 6
atomic_gaussian_width: 0.3
radial_basis: Gto
radial_basis:
Gto: {}
center_atom_weight: 1.0
cutoff_function:
ShiftedCosine:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def __init__(self, all_species: List[int], hypers: dict) -> None:
nns_per_species = []
for _ in all_species:
module_list = [
torch.nn.Linear(hypers["input_size"], hypers["num_neurons_per_layers"]),
torch.nn.Linear(hypers["input_size"], hypers["num_neurons_per_layer"]),
torch.nn.SiLU(),
]
for _ in range(hypers["num_hidden_layers"]):
module_list.append(torch.nn.Linear(hypers["num_neurons_per_layers"], hypers["num_neurons_per_layers"]))
module_list.append(torch.nn.Linear(hypers["num_neurons_per_layer"], hypers["num_neurons_per_layer"]))
module_list.append(torch.nn.SiLU())

# If there are no hidden layers, the number of inputs for the last layer is the input size
n_inputs_last_layer = hypers["num_neurons_per_layers"] if hypers["num_hidden_layers"] > 0 else hypers["input_size"]
n_inputs_last_layer = hypers["num_neurons_per_layer"] if hypers["num_hidden_layers"] > 0 else hypers["input_size"]

module_list.append(torch.nn.Linear(n_inputs_last_layer, hypers["output_size"]))
nns_per_species.append(torch.nn.Sequential(*module_list))
Expand All @@ -39,9 +39,14 @@ def __init__(self, all_species: List[int], hypers: dict) -> None:
})

def forward(self, features: TensorMap) -> TensorMap:

# Create a list of the blocks that are present in the features:
present_blocks = [int(key.values.item()) for key in features.keys]

new_blocks: List[TensorBlock] = []
for species_str, network in self.layers.items():
species = int(species_str)
if species not in present_blocks: continue
# Here, do we have to check that the species is actually present in the system?
block = features.block({"species_center": species})
output_values = network(block.values)
Expand All @@ -62,13 +67,13 @@ def forward(self, features: TensorMap) -> TensorMap:
class SoapBPNN(torch.nn.Module):
def __init__(self, all_species, hypers) -> None:
super().__init__()
self.soap_calculator = rascaline.torch.PowerSpectrum(
hypers["soap"]
self.soap_calculator = rascaline.torch.SoapPowerSpectrum(
**hypers["soap"]
)
hypers_bpnn = hypers["bpnn"]
hypers_bpnn["input_size"] = hypers["soap"]["max_radial"]**2 * (hypers["soap"]["max_angular"] + 1)
hypers_bpnn["input_size"] = len(all_species)**2 * hypers["soap"]["max_radial"]**2 * (hypers["soap"]["max_angular"] + 1)
hypers_bpnn["output_size"] = 1
self.bpnn = MLPMap(all_species, hypers["bpnn"])
self.bpnn = MLPMap(all_species, hypers_bpnn)
self.neighbor_species_1_labels = Labels(
names=["species_neighbor_1"],
values=torch.tensor(all_species).reshape(-1, 1)
Expand All @@ -90,8 +95,6 @@ def forward(self, systems: List[rascaline.torch.System]) -> Dict[str, TensorMap]
atomic_energies = atomic_energies.keys_to_samples("species_center")

# Sum the atomic energies coming from the BPNN to get the total energy
total_energies = metatensor.torch.sum(atomic_energies, ["center", "species_center"])
total_energies = metatensor.torch.sum_over_samples(atomic_energies, ["center", "species_center"])

return {"energy": total_energies}


1,201 changes: 1,201 additions & 0 deletions src/metatensor_models/soap_bpnn/tests/data/qm9_reduced_100.xyz

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
These 100 structures are the first 100 structures in the QM9 dataset.
They are used to test the model.
These 100 structures are the first 100 structures in the QM9 dataset; they are used to test the SOAP-BPNN model.
26 changes: 26 additions & 0 deletions src/metatensor_models/soap_bpnn/tests/test_invariance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import ase.io
import rascaline.torch
import copy

import yaml


from metatensor_models.soap_bpnn import SoapBPNN


def test_rotational_invariance():
"""Tests that the model is rotationally invariant."""

all_species = [1, 6, 7, 8, 9]
hypers = yaml.safe_load(open("../default.yml", "r"))
soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64)

structure = ase.io.read("data/qm9_reduced_100.xyz")
original_structure = copy.deepcopy(structure)
structure.rotate(48, "y")

original_output = soap_bpnn([rascaline.torch.systems_to_torch(original_structure)])
rotated_output = soap_bpnn([rascaline.torch.systems_to_torch(structure)])

assert torch.allclose(original_output["energy"].block().values, rotated_output["energy"].block().values)
36 changes: 36 additions & 0 deletions src/metatensor_models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
torch.random.manual_seed(0)
import ase.io
import rascaline.torch
import copy

import yaml

from metatensor_models.soap_bpnn import SoapBPNN


def test_regression_init():
"""Perform a regression test on the model at initialization"""

all_species = [1, 6, 7, 8, 9]
hypers = yaml.safe_load(open("../default.yml", "r"))
soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64)

structures = ase.io.read("data/qm9_reduced_100.xyz", ":5")

output = soap_bpnn([rascaline.torch.systems_to_torch(structure) for structure in structures])
expected_output = torch.tensor(
[[ 0.051073328644],
[ 0.226986056644],
[-0.069471667587],
[-0.218890301061],
[-0.043132789197]],
dtype=torch.float64
)

assert torch.allclose(output["energy"].block().values, expected_output)


def test_regression_train():
"""Perform a regression test on the model when trained for 2 epoch trained on a small dataset"""
# TODO: Implement this test
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def loss_function(predicted, target):
return torch.sum((predicted.block.values - target.block.values)**2)


def trainer(model, train_dataset, hypers):
def train(model, train_dataset, hypers):

# Create a dataloader for the training dataset:
train_dataloader = torch.utils.data.DataLoader(
Expand Down

0 comments on commit 47d0474

Please sign in to comment.