Skip to content

Commit

Permalink
Exclude mtt::aux:: quantities from composition models
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Aug 18, 2024
1 parent fc71849 commit 9a7718b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ def forward(
# at evaluation, we also add the composition contributions
composition_contributions = self.composition_model(systems, outputs)
for name in return_dict:
if name.startswith("mtt::aux::"):
continue
return_dict[name] = metatensor.torch.add(
return_dict[name],
composition_contributions[name],
Expand Down
6 changes: 6 additions & 0 deletions src/metatrain/experimental/soap_bpnn/tests/test_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def test_continue(monkeypatch, tmp_path):
}
}
targets, _ = read_targets(OmegaConf.create(conf))

# systems in float64 are required for training
systems = [system.to(torch.float64) for system in systems]
dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]})

hypers = DEFAULT_HYPERS.copy()
Expand All @@ -63,6 +66,9 @@ def test_continue(monkeypatch, tmp_path):
checkpoint_dir=".",
)

# evaluation
systems = [system.to(torch.float32) for system in systems]

# Predict on the first five systems
output_before = model_before(
systems[:5], {"mtt::U0": model_before.outputs["mtt::U0"]}
Expand Down
26 changes: 16 additions & 10 deletions src/metatrain/experimental/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ def test_regression_init():
)

expected_output = torch.tensor(
[[-0.03860], [0.11137], [0.09112], [-0.05634], [-0.02549]]
[
[-0.038599025458],
[ 0.111374437809],
[ 0.091115802526],
[-0.056339077652],
[-0.025491207838]
]
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
Expand Down Expand Up @@ -100,17 +106,17 @@ def test_regression_train():

expected_output = torch.tensor(
[
[-40.592571258545],
[-56.522350311279],
[-76.571365356445],
[-77.384849548340],
[-93.445365905762],
[-0.106249026954],
[ 0.039981484413],
[-0.142682999372],
[-0.031701669097],
[-0.016210660338]
]
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
Expand Down
6 changes: 5 additions & 1 deletion src/metatrain/utils/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def train_model(
dtype = datasets[0][0]["system"].positions.dtype
if dtype != torch.float64:
raise ValueError(
"The composition model only supports float64 during training."
"The composition model only supports float64 during training. "
f"Got dtype: {dtype}."
)

Expand Down Expand Up @@ -194,6 +194,8 @@ def forward(
device = systems[0].positions.device

for output_name in outputs:
if output_name.startswith("mtt::aux::"):
continue
if output_name not in self.output_to_output_index:
raise ValueError(
f"output key {output_name} is not supported by this composition "
Expand All @@ -210,6 +212,8 @@ def forward(
# number of atoms per atomic type.
targets_out: Dict[str, TensorMap] = {}
for target_key, target in outputs.items():
if target_key.startswith("mtt::aux::"):
continue
weights = self.weights[self.output_to_output_index[target_key]]
targets_list = []
sample_values: List[List[int]] = []
Expand Down

0 comments on commit 9a7718b

Please sign in to comment.