Skip to content

Commit

Permalink
Change fitting order of composition and ZBL
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 15, 2024
1 parent d3bb736 commit 9277267
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/metatrain/experimental/gap/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def train(
# Calculate and set the composition weights:
logger.info("Calculating composition weights")
# model.additive_models[0] is the composition model
model.additive_models[0].train_model(train_datasets)
model.additive_models[0].train_model(train_datasets, model.additive_models[1:])

logger.info("Setting up data loaders")
if len(train_datasets[0][0][output_name].keys) > 1:
Expand Down
4 changes: 3 additions & 1 deletion src/metatrain/experimental/nanopet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def train(

logger.info("Calculating composition weights")
model.additive_models[0].train_model( # this is the composition model
train_datasets, self.hypers["fixed_composition_weights"]
train_datasets,
model.additive_models[1:],
self.hypers["fixed_composition_weights"],
)

if self.hypers["scale_targets"]:
Expand Down
4 changes: 3 additions & 1 deletion src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def train(

logger.info("Calculating composition weights")
model.additive_models[0].train_model( # this is the composition model
train_datasets, self.hypers["fixed_composition_weights"]
train_datasets,
model.additive_models[1:],
self.hypers["fixed_composition_weights"],
)

if self.hypers["scale_targets"]:
Expand Down
65 changes: 40 additions & 25 deletions src/metatrain/utils/additive/composition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from typing import Dict, List, Optional, Union

import metatensor.torch
Expand All @@ -9,6 +8,8 @@

from ..data import Dataset, DatasetInfo, TargetInfo, get_all_targets, get_atomic_types
from ..jsonschema import validate
from ..transfer import systems_and_targets_to_device
from .remove import remove_additive


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,13 +69,16 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
def train_model(
self,
datasets: List[Union[Dataset, torch.utils.data.Subset]],
additive_models: List[torch.nn.Module],
fixed_weights: Optional[Dict[str, Dict[int, str]]] = None,
) -> None:
"""Train/fit the composition weights for the datasets.
:param datasets: Dataset(s) to calculate the composition weights for.
:param fixed_weights: Optional fixed weights to use for the composition model,
for one or more target quantities.
:param additive_models: Additive models to be removed from the targets
before calculating the statistics.
:raises ValueError: If the provided datasets contain unknown targets.
:raises ValueError: If the provided datasets contain unknown atomic types.
Expand All @@ -99,12 +103,13 @@ def train_model(

missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets)))
if missing_types:
warnings.warn(
logger.warning(
f"Provided `datasets` do not contain atomic types {missing_types}. "
f"Known types from initialization are {self.atomic_types}.",
stacklevel=2,
f"Known types from initialization are {self.atomic_types}."
)

device = self.weights.device

# Fill the weights for each "new" target (i.e. those that do not already
# have composition weights from a previous training run)
for target_key in self.new_targets:
Expand All @@ -130,24 +135,12 @@ def train_model(
datasets_with_target.append(dataset)
if len(datasets_with_target) == 0:
# this is a possibility when transfer learning
warnings.warn(
logger.warning(
f"Target {target_key} in the model's new capabilities is not "
"present in any of the training datasets.",
stacklevel=2,
"present in any of the training datasets."
)
continue

targets = torch.stack(
[
sample[target_key].block().values
for dataset in datasets_with_target
for sample in dataset
]
)

# remove component and property dimensions
targets = targets.squeeze(dim=(1, 2))

total_num_structures = sum(
[len(dataset) for dataset in datasets_with_target]
)
Expand All @@ -159,17 +152,39 @@ def train_model(
)

composition_features = torch.zeros(
(total_num_structures, len(self.atomic_types)), dtype=dtype
(total_num_structures, len(self.atomic_types)),
dtype=dtype,
device=device,
)
structure_index = 0
system_index = 0
targets_list = []

for dataset in datasets_with_target:
for sample in dataset:
structure = sample["system"]
systems = [sample["system"]]
targets = {target_key: sample[target_key]}
systems, targets = systems_and_targets_to_device(
systems, targets, device
)
for additive_model in additive_models:
target_info_dict = {
target_key: self.new_targets[target_key]
}
targets = remove_additive( # remove other additive models
systems,
targets,
additive_model,
target_info_dict,
)
for j, t in enumerate(self.atomic_types):
composition_features[structure_index, j] = torch.sum(
structure.types == t
composition_features[system_index, j] = torch.sum(
systems[0].types == t
)
structure_index += 1
system_index += 1
targets_list.append(targets[target_key].block().values)

all_targets = torch.concatenate(targets_list) # concatenate samples
all_targets = all_targets.squeeze(dim=-1) # remove property dimension

regularizer = 1e-20
while regularizer:
Expand All @@ -189,7 +204,7 @@ def train_model(
dtype=composition_features.dtype,
device=composition_features.device,
),
composition_features.T @ targets,
composition_features.T @ all_targets,
).to(self.weights.dtype)
)
break
Expand Down
6 changes: 2 additions & 4 deletions src/metatrain/utils/additive/zbl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from typing import Dict, List, Optional

import metatensor.torch
Expand Down Expand Up @@ -89,10 +88,9 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
if ase_covalent_radius == 0.2:
# 0.2 seems to be the default value when the covalent radius
# is not known/available
warnings.warn(
logger.warning(
f"Covalent radius for element {t} is not available in ASE. "
"Using a default value of 0.2 Å.",
stacklevel=2,
"Using a default value of 0.2 Å."
)
self.covalent_radii[i] = ase_covalent_radius

Expand Down
14 changes: 7 additions & 7 deletions tests/utils/test_additive.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_composition_model_train():
),
)

composition_model.train_model(dataset)
composition_model.train_model(dataset, [])
assert composition_model.weights.shape[0] == 1
assert composition_model.weights.shape[1] == 2
assert composition_model.output_name_to_output_index == {"energy": 0}
Expand All @@ -99,7 +99,7 @@ def test_composition_model_train():
composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64)
)

composition_model.train_model([dataset])
composition_model.train_model([dataset], [])
assert composition_model.weights.shape[0] == 1
assert composition_model.weights.shape[1] == 2
assert composition_model.output_name_to_output_index == {"energy": 0}
Expand All @@ -108,7 +108,7 @@ def test_composition_model_train():
composition_model.weights, torch.tensor([[2.0, 1.0]], dtype=torch.float64)
)

composition_model.train_model([dataset, dataset, dataset])
composition_model.train_model([dataset, dataset, dataset], [])
assert composition_model.weights.shape[0] == 1
assert composition_model.weights.shape[1] == 2
assert composition_model.output_name_to_output_index == {"energy": 0}
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_composition_model_predict():
),
)

composition_model.train_model(dataset)
composition_model.train_model(dataset, [])

# per_atom = False
output = composition_model(
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_remove_additive():
targets=target_info,
),
)
composition_model.train_model(dataset)
composition_model.train_model(dataset, [])

# concatenate all targets
targets["mtt::U0"] = metatensor.torch.join(targets["mtt::U0"], axis="samples")
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_composition_model_missing_types():
ValueError,
match="unknown atomic types",
):
composition_model.train_model(dataset)
composition_model.train_model(dataset, [])

composition_model = CompositionModel(
model_hypers={},
Expand All @@ -359,7 +359,7 @@ def test_composition_model_missing_types():
UserWarning,
match="do not contain atomic types",
):
composition_model.train_model(dataset)
composition_model.train_model(dataset, [])


def test_composition_model_wrong_target():
Expand Down

0 comments on commit 9277267

Please sign in to comment.