Skip to content

Commit

Permalink
Fix composition test
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 12, 2024
1 parent 2a52859 commit 13da2b8
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions tests/utils/test_additive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from metatrain.utils.additive import ZBL, CompositionModel, remove_additive
from metatrain.utils.data import Dataset, DatasetInfo
from metatrain.utils.data.readers import read_systems, read_targets
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
Expand Down Expand Up @@ -369,23 +372,33 @@ def test_composition_model_missing_types():

def test_composition_model_wrong_target():
"""
Test the error when a non-energy is fed to the composition model.
Test the error when a non-scalar is fed to the composition model.
"""
composition_model = CompositionModel(
model_hypers={},
dataset_info=DatasetInfo(
length_unit="angstrom",
atomic_types=[1],
targets={
"force": get_generic_target_info(
{
"quantity": "force",
"unit": "",
"type": {"cartesian": {"rank": 1}},
"num_properties": 1,
"per_atom": True,
}
)
},
),
)
# This should do nothing, because the target is not scalar and it should be
# ignored by the composition model. The warning is due to the "empty" dataset
# not containing H (atomic type 1)
with pytest.warns(UserWarning, match="do not contain atomic types"):
composition_model.train_model([])

with pytest.raises(
ValueError,
match="only supports energy-like outputs",
):
CompositionModel(
model_hypers={},
dataset_info=DatasetInfo(
length_unit="angstrom",
atomic_types=[1],
targets={
"energy": get_energy_target_info({"quantity": "force", "unit": ""})
},
),
)
assert composition_model.weights.shape == (0, 1)


def test_zbl():
Expand Down

0 comments on commit 13da2b8

Please sign in to comment.