Skip to content

Commit

Permalink
Enable training
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 30, 2023
1 parent d0e51be commit 85c8e13
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Models
*.pt
50 changes: 40 additions & 10 deletions src/metatensor_models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import rascaline.torch
import torch
import yaml
from metatensor_models.utils.data import Dataset, collate_fn
from metatensor_models.utils.data.readers import read_structures, read_targets

from metatensor_models.soap_bpnn import SoapBPNN
from metatensor_models.soap_bpnn import SoapBPNN, train


torch.random.manual_seed(0)
Expand All @@ -22,13 +24,11 @@ def test_regression_init():
[rascaline.torch.systems_to_torch(structure) for structure in structures]
)
expected_output = torch.tensor(
[
[0.278998736968],
[0.233572279098],
[0.011664706094],
[0.104852198342],
[0.059145453418],
],
[[ 0.051100484235],
[ 0.226915388550],
[-0.069549073530],
[-0.218989772242],
[-0.042997152257]],
dtype=torch.float64,
)

Expand All @@ -37,5 +37,35 @@ def test_regression_init():

def test_regression_train():
"""Perform a regression test on the model when
trained for 2 epoch trained on a small dataset"""
# TODO: Implement this test
trained for 2 epoch on a small dataset"""

all_species = [1, 6, 7, 8, 9]
hypers = yaml.safe_load(open("../default.yml", "r"))
hypers["epochs"] = 2
hypers["batch_size"] = 5
soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64)

structures = read_structures("data/qm9_reduced_100.xyz")
targets = read_targets("data/qm9_reduced_100.xyz", "U0")

dataset = Dataset(structures, targets)

hypers_training = hypers["training"]
hypers_training["num_epochs"] = 2
train(soap_bpnn, dataset, hypers_training)

output = soap_bpnn(structures[:5])
expected_output = torch.tensor(
[[-1.182792209483],
[-0.836589440867],
[-0.740011448717],
[-0.896406914741],
[-0.666903846884]],
dtype=torch.float64,
)

assert torch.allclose(output["energy"].block().values, expected_output)




12 changes: 8 additions & 4 deletions src/metatensor_models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from ..utils.data import collate_fn

import logging

import torch


def loss_function(predicted, target):
return torch.sum((predicted.block.values - target.block.values) ** 2)
return torch.sum((predicted.block().values - target.block().values) ** 2)


def train(model, train_dataset, hypers):
Expand All @@ -13,21 +15,23 @@ def train(model, train_dataset, hypers):
dataset=train_dataset,
batch_size=hypers["batch_size"],
shuffle=True,
collate_fn=collate_fn,
)

# Create an optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"])

# Train the model:
for epoch in range(hypers["epochs"]):
for epoch in range(hypers["num_epochs"]):
if epoch % hypers["log_interval"] == 0:
logging.info(f"Epoch {epoch}")
if epoch % hypers["checkpoint_interval"] == 0:
torch.save(model.state_dict(), f"model-{epoch}.pt")
for batch in train_dataloader:
optimizer.zero_grad()
predicted = model(batch)
loss = loss_function(predicted, batch)
structures, targets = batch
predicted = model(structures)
loss = loss_function(predicted["energy"], targets["U0"])
loss.backward()
optimizer.step()

Expand Down

0 comments on commit 85c8e13

Please sign in to comment.