Skip to content

Commit

Permalink
Fix alchemical tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jul 15, 2024
1 parent 08114b4 commit 5abec02
Showing 1 changed file with 8 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import random

import numpy as np
import pytest
import torch
from metatensor.torch.atomistic import ModelEvaluationOptions
from omegaconf import OmegaConf
Expand Down Expand Up @@ -36,7 +35,6 @@ def test_regression_init():
length_unit="Angstrom", atomic_types={1, 6, 7, 8}, targets=targets
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)
model.to(dtype=torch.float64)

# Predict on the first five systems
systems = read_systems(DATASET_PATH)[:5]
Expand All @@ -52,6 +50,7 @@ def test_regression_init():

exported = model.export()

systems = [system.to(dtype=torch.float32) for system in systems]
output = exported(systems, evaluation_options, check_consistency=True)

expected_output = torch.tensor(
Expand All @@ -62,21 +61,19 @@ def test_regression_init():
[-13.758152008057],
[-2.430717945099],
],
dtype=torch.float64,
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values,
expected_output,
)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_regression_train(dtype):
def test_regression_train():
"""Perform a regression test on the model when
trained for 2 epoch on a small dataset"""

Expand Down Expand Up @@ -113,7 +110,7 @@ def test_regression_train(dtype):
trainer = Trainer(hypers["training"])
trainer.train(
model=model,
dtype=dtype,
dtype=torch.float32,
devices=[torch.device("cpu")],
train_datasets=[dataset],
val_datasets=[dataset],
Expand All @@ -128,7 +125,7 @@ def test_regression_train(dtype):

exported = model.export()

systems = [system.to(dtype=dtype) for system in systems]
systems = [system.to(dtype=torch.float32) for system in systems]
output = exported(systems[:5], evaluation_options, check_consistency=True)

expected_output = torch.tensor(
Expand All @@ -139,12 +136,11 @@ def test_regression_train(dtype):
[-77.038444519043],
[-92.812789916992],
],
dtype=dtype,
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values,
Expand Down

0 comments on commit 5abec02

Please sign in to comment.