From 01bcf05814e3c75eb7ee18c32475450792665101 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Fri, 8 Dec 2023 11:21:22 +0100
Subject: [PATCH] First attempt

---
 src/metatensor/models/soap_bpnn/default.yml | 33 ++++++++++----------
 src/metatensor/models/soap_bpnn/model.py    |  3 ++
 src/metatensor/models/soap_bpnn/train.py    | 20 +++++++-----
 src/metatensor/models/utils/model_io.py     | 34 +++++++++++++++++++++
 tests/model_io/test_model_io.py             |  6 ++++
 5 files changed, 73 insertions(+), 23 deletions(-)
 create mode 100644 src/metatensor/models/utils/model_io.py
 create mode 100644 tests/model_io/test_model_io.py

diff --git a/src/metatensor/models/soap_bpnn/default.yml b/src/metatensor/models/soap_bpnn/default.yml
index 5fa4d8c5b..cd76aedc3 100644
--- a/src/metatensor/models/soap_bpnn/default.yml
+++ b/src/metatensor/models/soap_bpnn/default.yml
@@ -1,21 +1,22 @@
 # default hyperparameters for the SOAP-BPNN model
 
-soap:
-  cutoff: 5.0
-  max_radial: 8
-  max_angular: 6
-  atomic_gaussian_width: 0.3
-  radial_basis:
-    Gto: {}
-  center_atom_weight: 1.0
-  cutoff_function:
-    ShiftedCosine:
-      width: 1.0
-  radial_scaling:
-    Willatt2018:
-      rate: 1.0
-      scale: 2.0
-      exponent: 7.0
+model:
+  soap:
+    cutoff: 5.0
+    max_radial: 8
+    max_angular: 6
+    atomic_gaussian_width: 0.3
+    radial_basis:
+      Gto: {}
+    center_atom_weight: 1.0
+    cutoff_function:
+      ShiftedCosine:
+        width: 1.0
+    radial_scaling:
+      Willatt2018:
+        rate: 1.0
+        scale: 2.0
+        exponent: 7.0
 
 bpnn:
   num_hidden_layers: 2
diff --git a/src/metatensor/models/soap_bpnn/model.py b/src/metatensor/models/soap_bpnn/model.py
index 9e23e5b47..4416fcd87 100644
--- a/src/metatensor/models/soap_bpnn/model.py
+++ b/src/metatensor/models/soap_bpnn/model.py
@@ -8,6 +8,9 @@
 from ..utils.composition import apply_composition_contribution
 
 
+ARCHITECTURE_NAME = "soap_bpnn"
+
+
 class MLPMap(torch.nn.Module):
     def __init__(self, all_species: List[int], hypers: dict) -> None:
         super().__init__()
diff --git a/src/metatensor/models/soap_bpnn/train.py b/src/metatensor/models/soap_bpnn/train.py
index 5b5215573..3bb85db10 100644
--- a/src/metatensor/models/soap_bpnn/train.py
+++ b/src/metatensor/models/soap_bpnn/train.py
@@ -2,8 +2,11 @@
 
 import torch
 
+from .model import ARCHITECTURE_NAME
+
 from ..utils.composition import calculate_composition_weights
 from ..utils.data import collate_fn
+from ..utils.model_io import save_model
 
 
 def loss_function(predicted, target):
@@ -14,6 +17,9 @@
 
 
 def train(model, train_dataset, hypers):
+    model_hypers = hypers["model"]
+    training_hypers = hypers["training"]
+
     # Calculate and set the composition weights:
     composition_weights = calculate_composition_weights(train_dataset, "U0")
     model.set_composition_weights(composition_weights)
@@ -21,20 +27,20 @@
     # Create a dataloader for the training dataset:
     train_dataloader = torch.utils.data.DataLoader(
         dataset=train_dataset,
-        batch_size=hypers["batch_size"],
+        batch_size=training_hypers["batch_size"],
         shuffle=True,
         collate_fn=collate_fn,
     )
 
     # Create an optimizer:
-    optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"])
+    optimizer = torch.optim.Adam(model.parameters(), lr=training_hypers["learning_rate"])
 
     # Train the model:
-    for epoch in range(hypers["num_epochs"]):
-        if epoch % hypers["log_interval"] == 0:
+    for epoch in range(training_hypers["num_epochs"]):
+        if epoch % training_hypers["log_interval"] == 0:
             logger.info(f"Epoch {epoch}")
-        if epoch % hypers["checkpoint_interval"] == 0:
-            torch.save(model.state_dict(), f"model-{epoch}.pt")
+        if epoch % training_hypers["checkpoint_interval"] == 0:
+            save_model(ARCHITECTURE_NAME, model, model_hypers, model.all_species, f"model_{epoch}.pt")
         for batch in train_dataloader:
             optimizer.zero_grad()
             structures, targets = batch
@@ -44,4 +50,4 @@
             optimizer.step()
 
     # Save the model:
-    torch.save(model.state_dict(), "model_final.pt")
+    save_model(ARCHITECTURE_NAME, model, model_hypers, model.all_species, "model_final.pt")
diff --git a/src/metatensor/models/utils/model_io.py b/src/metatensor/models/utils/model_io.py
new file mode 100644
index 000000000..0ec22fff8
--- /dev/null
+++ b/src/metatensor/models/utils/model_io.py
@@ -0,0 +1,34 @@
+import torch
+
+from typing import Dict, List
+
+
+def save_model(arch_name: str, model: torch.nn.Module, hypers: Dict, all_species: List[int], path: str) -> None:
+    """Saves a model to a file, along with all the metadata needed to load it.
+
+    Parameters
+    ----------
+    arch_name (str): The name of the architecture.
+
+    model (torch.nn.Module): The model to save.
+
+    hypers (Dict): The hyperparameters used to train the model.
+
+    path (str): The path to the file.
+    """
+    torch.save({"name": arch_name, "model": model.state_dict(), "hypers": hypers, "all_species": all_species}, path)
+
+
+def load_model(path: str) -> torch.nn.Module:
+    """Loads a model from a file.
+
+    Parameters
+    ----------
+    path (str): The path to the file.
+
+    Returns
+    -------
+    torch.nn.Module: The loaded model.
+    """
+    # TODO, possibly with hydra utilities?
+    pass
diff --git a/tests/model_io/test_model_io.py b/tests/model_io/test_model_io.py
new file mode 100644
index 000000000..08181de91
--- /dev/null
+++ b/tests/model_io/test_model_io.py
@@ -0,0 +1,6 @@
+from metatensor.models import soap_bpnn
+
+
+def test_save_load_model():
+    """Test that saving and loading a model works."""
+    pass
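
Note on the load_model stub above: the patch leaves its body as a TODO. The following is a minimal sketch of one possible implementation, assuming each architecture package (here metatensor.models.soap_bpnn) exposes a Model class that can be rebuilt from (all_species, hypers). That constructor signature, and the Model name itself, are assumptions for illustration only and are not part of this patch.

# Sketch only; `architecture.Model(all_species, hypers)` is a hypothetical
# constructor, not an API confirmed by this patch.
import importlib

import torch


def load_model(path: str) -> torch.nn.Module:
    """Load a model previously written by ``save_model``."""
    # ``save_model`` stores a dict with "name", "model" (the state dict),
    # "hypers" and "all_species".
    checkpoint = torch.load(path)

    # Import the architecture package named in the checkpoint, e.g. "soap_bpnn".
    architecture = importlib.import_module(f"metatensor.models.{checkpoint['name']}")

    # Hypothetical: rebuild the model from the stored metadata, then restore weights.
    model = architecture.Model(checkpoint["all_species"], checkpoint["hypers"])
    model.load_state_dict(checkpoint["model"])
    return model

With something along these lines, test_save_load_model could do a round trip: build a model, call save_model to a temporary path, reload it with load_model, and compare the two state dicts.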