From 85c8e13bacab851dfb445b50f6778f318c9282f6 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 30 Nov 2023 17:53:36 +0100 Subject: [PATCH] Enable training --- .gitignore | 3 ++ .../soap_bpnn/tests/test_regression.py | 50 +++++++++++++++---- src/metatensor_models/soap_bpnn/train.py | 12 +++-- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 68bc17f9f..c017f06e5 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Models +*.pt diff --git a/src/metatensor_models/soap_bpnn/tests/test_regression.py b/src/metatensor_models/soap_bpnn/tests/test_regression.py index 79a25e669..94ae26717 100644 --- a/src/metatensor_models/soap_bpnn/tests/test_regression.py +++ b/src/metatensor_models/soap_bpnn/tests/test_regression.py @@ -2,8 +2,10 @@ import rascaline.torch import torch import yaml +from metatensor_models.utils.data import Dataset, collate_fn +from metatensor_models.utils.data.readers import read_structures, read_targets -from metatensor_models.soap_bpnn import SoapBPNN +from metatensor_models.soap_bpnn import SoapBPNN, train torch.random.manual_seed(0) @@ -22,13 +24,11 @@ def test_regression_init(): [rascaline.torch.systems_to_torch(structure) for structure in structures] ) expected_output = torch.tensor( - [ - [0.278998736968], - [0.233572279098], - [0.011664706094], - [0.104852198342], - [0.059145453418], - ], + [[ 0.051100484235], + [ 0.226915388550], + [-0.069549073530], + [-0.218989772242], + [-0.042997152257]], dtype=torch.float64, ) @@ -37,5 +37,35 @@ def test_regression_init(): def test_regression_train(): """Perform a regression test on the model when - trained for 2 epoch trained on a small dataset""" - # TODO: Implement this test + trained for 2 epoch on a small dataset""" + + all_species = [1, 6, 7, 8, 9] + hypers = yaml.safe_load(open("../default.yml", "r")) + hypers["epochs"] = 2 + hypers["batch_size"] = 5 + soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64) + + structures = read_structures("data/qm9_reduced_100.xyz") + targets = read_targets("data/qm9_reduced_100.xyz", "U0") + + dataset = Dataset(structures, targets) + + hypers_training = hypers["training"] + hypers_training["num_epochs"] = 2 + train(soap_bpnn, dataset, hypers_training) + + output = soap_bpnn(structures[:5]) + expected_output = torch.tensor( + [[-1.182792209483], + [-0.836589440867], + [-0.740011448717], + [-0.896406914741], + [-0.666903846884]], + dtype=torch.float64, + ) + + assert torch.allclose(output["energy"].block().values, expected_output) + + + + diff --git a/src/metatensor_models/soap_bpnn/train.py b/src/metatensor_models/soap_bpnn/train.py index 6386bf0da..cb6aeef26 100644 --- a/src/metatensor_models/soap_bpnn/train.py +++ b/src/metatensor_models/soap_bpnn/train.py @@ -1,10 +1,12 @@ +from ..utils.data import collate_fn + import logging import torch def loss_function(predicted, target): - return torch.sum((predicted.block.values - target.block.values) ** 2) + return torch.sum((predicted.block().values - target.block().values) ** 2) def train(model, train_dataset, hypers): @@ -13,21 +15,23 @@ def train(model, train_dataset, hypers): dataset=train_dataset, batch_size=hypers["batch_size"], shuffle=True, + collate_fn=collate_fn, ) # Create an optimizer: optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"]) # Train the model: - for epoch in range(hypers["epochs"]): + for epoch in range(hypers["num_epochs"]): if epoch % hypers["log_interval"] == 0: logging.info(f"Epoch {epoch}") if epoch % hypers["checkpoint_interval"] == 0: torch.save(model.state_dict(), f"model-{epoch}.pt") for batch in train_dataloader: optimizer.zero_grad() - predicted = model(batch) - loss = loss_function(predicted, batch) + structures, targets = batch + predicted = model(structures) + loss = loss_function(predicted["energy"], targets["U0"]) loss.backward() optimizer.step()