Add gradient calculators (#26)
* Add gradient calculator
* Temporary losses
* Forces and stresses
* Support multiple model outputs in SOAP-BPNN
frostedoyster authored Jan 13, 2024
1 parent 4d7f847 commit 51df872
Showing 16 changed files with 1,541 additions and 76 deletions.
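At a glance, the headline change is that one SOAP-BPNN model can now declare and compute several outputs at once. Before the per-file diffs, here is a minimal construction sketch of the new API; the output names, species list, and geometry are illustrative, not taken from the commit:

    import ase
    import rascaline.torch
    import torch
    from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

    from metatensor.models.soap_bpnn.model import Model

    torch.set_default_dtype(torch.float64)  # the regression tests run in float64

    # Two energy-like outputs; the names "U0" and "U298" are made up.
    capabilities = ModelCapabilities(
        length_unit="Angstrom",
        species=[1, 8],
        outputs={
            "U0": ModelOutput(quantity="energy", unit="eV"),
            "U298": ModelOutput(quantity="energy", unit="eV"),
        },
    )
    model = Model(capabilities)  # hypers default to DEFAULT_MODEL_HYPERS

    # A single water molecule as input, assuming systems_to_torch accepts an
    # ASE Atoms object as it does in the regression tests below:
    atoms = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 0.96], [0.93, 0, -0.24]])
    systems = [rascaline.torch.systems_to_torch(atoms)]

    predictions = model(systems)      # computes every declared output
    only_u0 = model(systems, ["U0"])  # computes just the requested output

Each entry of the returned dictionary is a TensorMap of per-structure energies, keyed by output name.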
4 changes: 4 additions & 0 deletions src/metatensor/models/cli/eval_model.py
@@ -56,5 +56,9 @@ def eval_model(model: str, structures: str, output: str = "output.xyz") -> None:
 
     loaded_model = load_model(model)
     structure_list = read_structures(structures)
+
+    # since the second argument is missing,
+    # this calculates all the available properties:
     predictions = loaded_model(structure_list)
+
     write_predictions(output, predictions, structure_list)
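The new comment refers to the optional second argument added to the model's forward in model.py below: leaving it out computes every output the model declares. Conversely, a caller can restrict the computation; a one-line sketch with an illustrative output name:

    # Compute only the "U0" output (name illustrative); others are skipped:
    predictions = loaded_model(structure_list, ["U0"])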
183 changes: 135 additions & 48 deletions src/metatensor/models/soap_bpnn/model.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import metatensor.torch
 import rascaline.torch
@@ -35,29 +35,23 @@ def __init__(self, all_species: List[int], hypers: dict) -> None:
         # Build a neural network for each species
         nns_per_species = []
         for _ in all_species:
-            module_list = [
-                torch.nn.Linear(hypers["input_size"], hypers["num_neurons_per_layer"]),
-                torch.nn.SiLU(),
-            ]
+            module_list: List[torch.nn.Module] = []
             for _ in range(hypers["num_hidden_layers"]):
-                module_list.append(
-                    torch.nn.Linear(
-                        hypers["num_neurons_per_layer"], hypers["num_neurons_per_layer"]
-                    )
-                )
-                module_list.append(torch.nn.SiLU())
-
-            # If there are no hidden layers, the number of inputs
-            # for the last layer is the input size
-            n_inputs_last_layer = (
-                hypers["num_neurons_per_layer"]
-                if hypers["num_hidden_layers"] > 0
-                else hypers["input_size"]
-            )
-
-            module_list.append(
-                torch.nn.Linear(n_inputs_last_layer, hypers["output_size"])
-            )
+                if len(module_list) == 0:
+                    module_list.append(
+                        torch.nn.Linear(
+                            hypers["input_size"], hypers["num_neurons_per_layer"]
+                        )
+                    )
+                else:
+                    module_list.append(
+                        torch.nn.Linear(
+                            hypers["num_neurons_per_layer"],
+                            hypers["num_neurons_per_layer"],
+                        )
+                    )
+                module_list.append(self.activation_function)
+
             nns_per_species.append(torch.nn.Sequential(*module_list))
 
         # Create a module dict to store the neural networks
@@ -75,9 +69,7 @@ def forward(self, features: TensorMap) -> TensorMap:
         new_blocks: List[TensorBlock] = []
         for species_str, network in self.layers.items():
             species = int(species_str)
-            if species not in present_blocks:
-                pass  # continue is not accepted by torchscript here
-            else:
+            if species in present_blocks:
                 block = features.block({"species_center": species})
                 output_values = network(block.values)
                 new_blocks.append(
@@ -92,6 +84,48 @@ def forward(self, features: TensorMap) -> TensorMap:
         return TensorMap(keys=features.keys, blocks=new_blocks)
 
 
+class LinearMap(torch.nn.Module):
+    def __init__(self, all_species: List[int], n_inputs: int) -> None:
+        super().__init__()
+
+        # Build a linear layer for each species
+        layer_per_species = []
+        for _ in all_species:
+            layer_per_species.append(torch.nn.Linear(n_inputs, 1))
+
+        # Create a module dict to store the linear layers
+        self.layers = torch.nn.ModuleDict(
+            {
+                str(species): layer
+                for species, layer in zip(all_species, layer_per_species)
+            }
+        )
+
+    def forward(self, features: TensorMap) -> TensorMap:
+        # Create a list of the blocks that are present in the features:
+        present_blocks = [
+            int(features.keys.entry(i).values.item())
+            for i in range(features.keys.values.shape[0])
+        ]
+
+        new_blocks: List[TensorBlock] = []
+        for species_str, layer in self.layers.items():
+            species = int(species_str)
+            if species in present_blocks:
+                block = features.block({"species_center": species})
+                output_values = layer(block.values)
+                new_blocks.append(
+                    TensorBlock(
+                        values=output_values,
+                        samples=block.samples,
+                        components=block.components,
+                        properties=Labels.single(),
+                    )
+                )
+
+        return TensorMap(keys=features.keys, blocks=new_blocks)
+
+
 class Model(torch.nn.Module):
     def __init__(
         self, capabilities: ModelCapabilities, hypers: Dict = DEFAULT_MODEL_HYPERS
@@ -100,16 +134,18 @@ def __init__(
         self.name = ARCHITECTURE_NAME
 
         # Check capabilities
-        if len(capabilities.outputs) > 1:
-            raise ValueError(
-                "SOAP-BPNN only supports a single output, "
-                "but multiple outputs were provided"
-            )
-        if next(iter(capabilities.outputs.values())).quantity != "energy":
-            raise ValueError(
-                "SOAP-BPNN only supports energy-like outputs, "
-                f"but {next(iter(capabilities.outputs.values())).quantity} was provided"
-            )
+        for output in capabilities.outputs.values():
+            if output.quantity != "energy":
+                raise ValueError(
+                    "SOAP-BPNN only supports energy-like outputs, "
+                    f"but a {output.quantity} was provided"
+                )
+            if output.per_atom:
+                raise ValueError(
+                    "SOAP-BPNN only supports per-structure outputs, "
+                    "but a per-atom output was provided"
+                )
 
         self.capabilities = capabilities
         self.all_species = capabilities.species
@@ -118,9 +154,16 @@ def __init__(
         # creates a composition weight tensor that can be directly indexed by species,
         # this can be left as a tensor of zero or set from the outside using
         # set_composition_weights (recommended for better accuracy)
+        n_outputs = len(capabilities.outputs)
         self.register_buffer(
-            "composition_weights", torch.zeros(max(self.all_species) + 1)
+            "composition_weights", torch.zeros((n_outputs, max(self.all_species) + 1))
         )
+        # buffers cannot be indexed by strings (torchscript), so we create a single
+        # tensor for all outputs; we then slice it at use time, selecting the
+        # correct row for each output via this dictionary:
+        self.output_to_index = {
+            output_name: i for i, output_name in enumerate(capabilities.outputs.keys())
+        }
 
         self.soap_calculator = rascaline.torch.SoapPowerSpectrum(**hypers["soap"])
         hypers_bpnn = hypers["bpnn"]
@@ -129,7 +172,7 @@ def __init__(
             * hypers["soap"]["max_radial"] ** 2
             * (hypers["soap"]["max_angular"] + 1)
         )
-        hypers_bpnn["output_size"] = 1
+
         self.bpnn = MLPMap(self.all_species, hypers_bpnn)
         self.neighbor_species_1_labels = Labels(
             names=["species_neighbor_1"],
@@ -140,7 +183,31 @@ def __init__(
             values=torch.tensor(self.all_species).reshape(-1, 1),
         )
 
-    def forward(self, systems: List[System]) -> Dict[str, TensorMap]:
+        if hypers_bpnn["num_hidden_layers"] == 0:
+            n_inputs_last_layer = hypers_bpnn["input_size"]
+        else:
+            n_inputs_last_layer = hypers_bpnn["num_neurons_per_layer"]
+
+        self.last_layers = torch.nn.ModuleDict(
+            {
+                output_name: LinearMap(self.all_species, n_inputs_last_layer)
+                for output_name in capabilities.outputs.keys()
+            }
+        )
+
+    def forward(
+        self, systems: List[System], requested_outputs: Optional[List[str]] = None
+    ) -> Dict[str, TensorMap]:
+        if requested_outputs is None:  # default to all outputs
+            requested_outputs = list(self.capabilities.outputs.keys())
+
+        for requested_output in requested_outputs:
+            if requested_output not in self.capabilities.outputs.keys():
+                raise ValueError(
+                    f"Requested output {requested_output} is not within "
+                    "the model's capabilities."
+                )
+
         soap_features = self.soap_calculator(systems)
 
         device = soap_features.block(0).values.device
@@ -151,19 +218,39 @@ def forward(self, systems: List[System]) -> Dict[str, TensorMap]:
             self.neighbor_species_2_labels.to(device)
         )
 
-        atomic_energies = self.bpnn(soap_features)
-        atomic_energies = apply_composition_contribution(
-            atomic_energies, self.composition_weights
-        )
-        atomic_energies = atomic_energies.keys_to_samples("species_center")
+        hidden_features = self.bpnn(soap_features)
+
+        atomic_energies: Dict[str, metatensor.torch.TensorMap] = {}
+        for output_name, output_layer in self.last_layers.items():
+            if output_name in requested_outputs:
+                atomic_energies[output_name] = apply_composition_contribution(
+                    output_layer(hidden_features),
+                    self.composition_weights[self.output_to_index[output_name]],
+                )
 
         # Sum the atomic energies coming from the BPNN to get the total energy
-        total_energies = metatensor.torch.sum_over_samples(
-            atomic_energies, ["center", "species_center"]
-        )
+        total_energies: Dict[str, metatensor.torch.TensorMap] = {}
+        for output_name, atomic_energy in atomic_energies.items():
+            atomic_energy = atomic_energy.keys_to_samples("species_center")
+            total_energies[output_name] = metatensor.torch.sum_over_samples(
+                atomic_energy, ["center", "species_center"]
+            )
+            # Change the energy label from _ to (0, 1):
+            total_energies[output_name] = metatensor.torch.TensorMap(
+                keys=Labels(
+                    names=["lambda", "sigma"],
+                    values=torch.tensor([[0, 1]]),
+                ),
+                blocks=[total_energies[output_name].block()],
+            )
 
-        return {"energy": total_energies}
+        return total_energies
 
-    def set_composition_weights(self, input_composition_weights: torch.Tensor) -> None:
+    def set_composition_weights(
+        self, output_name: str, input_composition_weights: torch.Tensor
+    ) -> None:
+        """Set the composition weights for a given output."""
         # all species that are not present retain their weight of zero
-        self.composition_weights[self.all_species] = input_composition_weights
+        self.composition_weights[self.output_to_index[output_name]][
+            self.all_species
+        ] = input_composition_weights
12 changes: 7 additions & 5 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
@@ -35,12 +35,12 @@ def test_regression_init():
         [rascaline.torch.systems_to_torch(structure) for structure in structures]
     )
     expected_output = torch.tensor(
-        [
-            [[0.5021], [0.3809], [0.1849], [0.2126], [0.0920]],
-        ],
+        [[-0.4615], [-0.4367], [-0.3004], [-0.2606], [-0.2380]],
         dtype=torch.float64,
     )
 
+    print(output["energy"].block().values)
+
     assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3)
 

@@ -61,8 +61,10 @@ def test_regression_train():
     output = soap_bpnn(structures[:5])
 
     expected_output = torch.tensor(
-        [[-40.5923], [-56.5135], [-76.4457], [-77.2500], [-93.1583]],
+        [[-39.9658], [-56.0888], [-76.1100], [-76.9461], [-93.0914]],
         dtype=torch.float64,
     )
 
-    assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3)
+    print(output["U0"].block().values)
+
+    assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3)
36 changes: 19 additions & 17 deletions src/metatensor/models/soap_bpnn/train.py
@@ -5,25 +5,31 @@
 from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
 
 from ..utils.composition import calculate_composition_weights
+from ..utils.compute_loss import compute_model_loss
 from ..utils.data import collate_fn
+from ..utils.loss import TensorMapDictLoss
 from ..utils.model_io import save_model
 from .model import DEFAULT_HYPERS, Model
 
 
 logger = logging.getLogger(__name__)
 
 
-def loss_function(predicted, target):
-    return torch.sum((predicted.block().values - target.block().values) ** 2)
-
-
 def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
+    if len(train_dataset.targets) > 1:
+        raise ValueError(
+            f"`train_dataset` contains {len(train_dataset.targets)} targets but we "
+            "currently only support a single target value!"
+        )
+    else:
+        target_name = list(train_dataset.targets.keys())[0]
+
     # Set the model's capabilities:
     model_capabilities = ModelCapabilities(
         length_unit="Angstrom",
         species=train_dataset.all_species,
         outputs={
-            "U0": ModelOutput(
+            target_name: ModelOutput(
                 quantity="energy",
                 unit="eV",
             )
@@ -36,17 +42,9 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
         hypers=hypers["model"],
     )
 
-    if len(train_dataset.targets) > 1:
-        raise ValueError(
-            f"`train_dataset` contains {len(train_dataset.targets)} targets but we "
-            "currently only support a single target value!"
-        )
-    else:
-        target = list(train_dataset.targets.keys())[0]
-
     # Calculate and set the composition weights:
-    composition_weights = calculate_composition_weights(train_dataset, target)
-    model.set_composition_weights(composition_weights)
+    composition_weights = calculate_composition_weights(train_dataset, target_name)
+    model.set_composition_weights(target_name, composition_weights)
 
     hypers_training = hypers["training"]
 
@@ -58,6 +56,11 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
         collate_fn=collate_fn,
     )
 
+    # Create a loss function:
+    loss_fn = TensorMapDictLoss(
+        {target_name: {"values": 1.0}},
+    )
+
     # Create an optimizer:
     optimizer = torch.optim.Adam(
         model.parameters(), lr=hypers_training["learning_rate"]
@@ -75,8 +78,7 @@ def train(train_dataset, hypers=DEFAULT_HYPERS, output_dir="."):
         for batch in train_dataloader:
             optimizer.zero_grad()
             structures, targets = batch
-            predicted = model(structures)
-            loss = loss_function(predicted["energy"], targets["U0"])
+            loss = compute_model_loss(loss_fn, model, structures, targets)
             loss.backward()
             optimizer.step()
 
