Skip to content

Commit

Permalink
First attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 8, 2023
1 parent 09c58cd commit 01bcf05
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 23 deletions.
33 changes: 17 additions & 16 deletions src/metatensor/models/soap_bpnn/default.yml
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
# default hyperparameters for the SOAP-BPNN model

soap:
cutoff: 5.0
max_radial: 8
max_angular: 6
atomic_gaussian_width: 0.3
radial_basis:
Gto: {}
center_atom_weight: 1.0
cutoff_function:
ShiftedCosine:
width: 1.0
radial_scaling:
Willatt2018:
rate: 1.0
scale: 2.0
exponent: 7.0
model:
soap:
cutoff: 5.0
max_radial: 8
max_angular: 6
atomic_gaussian_width: 0.3
radial_basis:
Gto: {}
center_atom_weight: 1.0
cutoff_function:
ShiftedCosine:
width: 1.0
radial_scaling:
Willatt2018:
rate: 1.0
scale: 2.0
exponent: 7.0

bpnn:
num_hidden_layers: 2
Expand Down
3 changes: 3 additions & 0 deletions src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from ..utils.composition import apply_composition_contribution


ARCHITECTURE_NAME = "soap_bpnn"


class MLPMap(torch.nn.Module):
def __init__(self, all_species: List[int], hypers: dict) -> None:
super().__init__()
Expand Down
20 changes: 13 additions & 7 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import torch

from .model import ARCHITECTURE_NAME

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model


def loss_function(predicted, target):
Expand All @@ -14,27 +17,30 @@ def loss_function(predicted, target):


def train(model, train_dataset, hypers):
model_hypers = hypers["model"]
training_hypers = hypers["training"]

# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, "U0")
model.set_composition_weights(composition_weights)

# Create a dataloader for the training dataset:
train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=hypers["batch_size"],
batch_size=training_hypers["batch_size"],
shuffle=True,
collate_fn=collate_fn,
)

# Create an optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"])
optimizer = torch.optim.Adam(model.parameters(), lr=training_hypers["learning_rate"])

# Train the model:
for epoch in range(hypers["num_epochs"]):
if epoch % hypers["log_interval"] == 0:
for epoch in range(training_hypers["num_epochs"]):
if epoch % training_hypers["log_interval"] == 0:
logger.info(f"Epoch {epoch}")
if epoch % hypers["checkpoint_interval"] == 0:
torch.save(model.state_dict(), f"model-{epoch}.pt")
if epoch % training_hypers["checkpoint_interval"] == 0:
save_model(ARCHITECTURE_NAME, model, model_hypers, model.all_species, f"model_{epoch}.pt")
for batch in train_dataloader:
optimizer.zero_grad()
structures, targets = batch
Expand All @@ -44,4 +50,4 @@ def train(model, train_dataset, hypers):
optimizer.step()

# Save the model:
torch.save(model.state_dict(), "model_final.pt")
save_model(ARCHITECTURE_NAME, model, model_hypers, model.all_species, f"model_{epoch}.pt")
34 changes: 34 additions & 0 deletions src/metatensor/models/utils/model_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch

from typing import Dict, List


def save_model(arch_name: str, model: torch.nn.Module, hypers: Dict, all_species: List[int], path: str) -> None:
"""Saves a model to a file, along with all the metadata needed to load it.
Parameters
----------
arch_name (str): The name of the architecture.
model (torch.nn.Module): The model to save.
hypers (Dict): The hyperparameters used to train the model.
path (str): The path to the file.
"""
torch.save({"name": arch_name, "model": model.state_dict(), "hypers": hypers, "all_species": all_species}, path)


def load_model(path: str) -> torch.nn.Module:
"""Loads a model from a file.
Parameters
----------
path (str): The path to the file.
Returns
-------
torch.nn.Module: The loaded model.
"""
# TODO, possibly with hydra utilities?
pass
6 changes: 6 additions & 0 deletions tests/model_io/test_model_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from metatensor.models import soap_bpnn


def test_save_load_model():
"""Test that saving and loading a model works."""
pass

0 comments on commit 01bcf05

Please sign in to comment.