Make sure mean of the LLPR ensembles is consistent
frostedoyster committed Dec 15, 2024
1 parent 9277267 commit aded37d
Showing 3 changed files with 34 additions and 6 deletions.
19 changes: 19 additions & 0 deletions src/metatrain/utils/llpr.py
@@ -226,6 +226,25 @@ def forward(
                 ll_features.block().values,
                 ensemble_weights,
             )
+
+            # since we know the exact mean of the ensemble from the model's prediction,
+            # it should be mathematically correct to use it to re-center the ensemble.
+            # Besides making sure that the average is always correct (so that results
+            # will always be consistent between LLPR ensembles and the original model),
+            # this also takes care of additive contributions that are not present in the
+            # last layer, which can be composition, short-range models, a bias in the
+            # last layer, etc.
+            original_name = (
+                name.replace("_ensemble", "").replace("aux::", "")
+                if name.replace("_ensemble", "").replace("aux::", "") in outputs
+                else name.replace("_ensemble", "").replace("mtt::aux::", "")
+            )
+            ensemble_values = (
+                ensemble_values
+                - ensemble_values.mean(dim=1, keepdim=True)
+                + return_dict[original_name].block().values
+            )
+
             property_name = "energy" if name == "energy_ensemble" else "ensemble_member"
             ensemble = TensorMap(
                 keys=Labels(
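The comment in the diff above carries the core idea: shifting every ensemble member by (model prediction − ensemble mean) forces the ensemble mean to equal the model's own prediction while leaving the member-to-member spread untouched. A minimal standalone sketch of that identity (toy tensors, not the metatrain objects; `model_prediction` stands in for `return_dict[original_name].block().values`):

    import torch

    # Toy shapes: 4 structures, 16 ensemble members.
    ensemble_values = torch.randn(4, 16)
    model_prediction = torch.randn(4, 1)

    # Re-center: drop the ensemble's own mean, add the model prediction back.
    recentered = (
        ensemble_values
        - ensemble_values.mean(dim=1, keepdim=True)
        + model_prediction
    )

    # The mean now matches the prediction (up to float round-off) ...
    torch.testing.assert_close(recentered.mean(dim=1, keepdim=True), model_prediction)
    # ... and the variance across members is unchanged.
    torch.testing.assert_close(recentered.var(dim=1), ensemble_values.var(dim=1))

Because only the mean is shifted, additive contributions that never pass through the last layer (composition, short-range models, a last-layer bias) are automatically absorbed into every ensemble member.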
12 changes: 6 additions & 6 deletions tests/utils/test_additive.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 
 import metatensor.torch
@@ -272,9 +273,9 @@ def test_remove_additive():
     assert std_after < 100.0 * std_before
 
 
-def test_composition_model_missing_types():
+def test_composition_model_missing_types(caplog):
     """
-    Test the error when there are too many or too types in the dataset
+    Test the error when there are too many types in the dataset
     compared to those declared at initialization.
     """
 
@@ -355,11 +356,10 @@ def test_composition_model_missing_types():
             targets={"energy": get_energy_target_info({"unit": "eV"})},
         ),
     )
-    with pytest.warns(
-        UserWarning,
-        match="do not contain atomic types",
-    ):
+    # need to capture the warning from the logger
+    with caplog.at_level(logging.WARNING):
         composition_model.train_model(dataset, [])
+    assert "do not contain atomic types" in caplog.text
 
 
 def test_composition_model_wrong_target():
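The test change above follows from the composition model now emitting its missing-types warning through the `logging` module rather than `warnings.warn`, so `pytest.warns` can no longer intercept it; pytest's built-in `caplog` fixture captures log records instead. A minimal sketch of the pattern (hypothetical `train_model` stand-in, not the metatrain function):

    import logging

    logger = logging.getLogger(__name__)

    def train_model():
        # Stand-in for composition_model.train_model: the warning goes
        # through the logger, so pytest.warns would not see it.
        logger.warning("Some structures do not contain atomic types declared at initialization")

    def test_warning_is_logged(caplog):
        with caplog.at_level(logging.WARNING):
            train_model()
        assert "do not contain atomic types" in caplog.text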
9 changes: 9 additions & 0 deletions tests/utils/test_llpr.py
@@ -278,11 +278,20 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir):
     assert "mtt::aux::energy_uncertainty" in outputs
     assert "energy_ensemble" in outputs
 
+    predictions = outputs["energy"].block().values
     analytical_uncertainty = outputs["mtt::aux::energy_uncertainty"].block().values
+    ensemble_mean = torch.mean(
+        outputs["energy_ensemble"].block().values, dim=1, keepdim=True
+    )
     ensemble_uncertainty = torch.var(
         outputs["energy_ensemble"].block().values, dim=1, keepdim=True
     )
 
+    print(predictions)
+    print(ensemble_mean)
+    print(predictions - ensemble_mean)
+
+    torch.testing.assert_close(predictions, ensemble_mean, rtol=5e-3, atol=0.0)
     torch.testing.assert_close(
         analytical_uncertainty, ensemble_uncertainty, rtol=5e-3, atol=0.0
     )
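The final assertions encode the two consistency claims: the ensemble mean must reproduce the model's predictions (guaranteed by the re-centering above), and the ensemble variance must match the analytical LLPR uncertainty. The second rests on the statistical identity var(f·w) = fᵀΣf for weights drawn from a distribution with covariance Σ; a self-contained sketch with toy features (hypothetical sizes and covariance, not the metatrain objects):

    import torch

    torch.manual_seed(0)

    # Toy setup: n structures, d last-layer features, a weight covariance.
    n, d, members = 8, 4, 200_000
    f = torch.randn(n, d)
    a = torch.randn(d, d)
    sigma = a @ a.T / d + 1e-3 * torch.eye(d)  # positive-definite covariance

    # Analytical predictive variance for a linear last layer: diag(f @ sigma @ f.T).
    analytical = torch.einsum("id,de,ie->i", f, sigma, f)

    # Ensemble route: sample weights from N(0, sigma), predict, take the variance.
    weights = torch.distributions.MultivariateNormal(
        torch.zeros(d), covariance_matrix=sigma
    ).sample((members,))
    empirical = (f @ weights.T).var(dim=1)

    # The Monte Carlo error shrinks as 1/sqrt(members), so with enough members
    # the empirical variance tracks the analytical one within a loose tolerance.
    torch.testing.assert_close(empirical, analytical, rtol=5e-2, atol=0.0)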
