From de1d9c6a52979f72b5a7f4e8630246613a43aa76 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 16 Aug 2024 20:27:57 +0200 Subject: [PATCH] Exclude `mtt::aux::` quantities from composition models --- src/metatrain/experimental/soap_bpnn/model.py | 2 ++ .../soap_bpnn/tests/test_continue.py | 6 +++++ .../soap_bpnn/tests/test_regression.py | 22 ++++++++++++------- src/metatrain/utils/composition.py | 6 ++++- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index c0c2a585..a3d0d4ea 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -285,6 +285,8 @@ def forward( # at evaluation, we also add the composition contributions composition_contributions = self.composition_model(systems, outputs) for name in return_dict: + if name.startswith("mtt::aux::"): + continue return_dict[name] = metatensor.torch.add( return_dict[name], composition_contributions[name], diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index 9bd9b0e6..c1ae1eb4 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -44,6 +44,9 @@ def test_continue(monkeypatch, tmp_path): } } targets, _ = read_targets(OmegaConf.create(conf)) + + # systems in float64 are required for training + systems = [system.to(torch.float64) for system in systems] dataset = Dataset({"system": systems, "mtt::U0": targets["mtt::U0"]}) hypers = DEFAULT_HYPERS.copy() @@ -63,6 +66,9 @@ def test_continue(monkeypatch, tmp_path): checkpoint_dir=".", ) + # evaluation + systems = [system.to(torch.float32) for system in systems] + # Predict on the first five systems output_before = model_before( systems[:5], {"mtt::U0": model_before.outputs["mtt::U0"]} diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 07e3871a..6f43f499 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -39,12 +39,18 @@ def test_regression_init(): ) expected_output = torch.tensor( - [[-0.03860], [0.11137], [0.09112], [-0.05634], [-0.02549]] + [ + [0.515075504780], + [0.241114050150], + [0.196303755045], + [0.116181179881], + [-0.130566224456], + ] ) # if you need to change the hardcoded values: - # torch.set_printoptions(precision=12) - # print(output["mtt::U0"].block().values) + torch.set_printoptions(precision=12) + print(output["mtt::U0"].block().values) torch.testing.assert_close( output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5 @@ -100,11 +106,11 @@ def test_regression_train(): expected_output = torch.tensor( [ - [-40.592571258545], - [-56.522350311279], - [-76.571365356445], - [-77.384849548340], - [-93.445365905762], + [0.274103075266], + [0.297527313232], + [-0.024770291522], + [0.097502753139], + [0.048697717488], ] ) diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py index 76a63b47..406a6584 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/composition.py @@ -123,7 +123,7 @@ def train_model( dtype = datasets[0][0]["system"].positions.dtype if dtype != torch.float64: raise ValueError( - "The composition model only supports float64 during training." + "The composition model only supports float64 during training. " f"Got dtype: {dtype}." ) @@ -194,6 +194,8 @@ def forward( device = systems[0].positions.device for output_name in outputs: + if output_name.startswith("mtt::aux::"): + continue if output_name not in self.output_to_output_index: raise ValueError( f"output key {output_name} is not supported by this composition " @@ -210,6 +212,8 @@ def forward( # number of atoms per atomic type. targets_out: Dict[str, TensorMap] = {} for target_key, target in outputs.items(): + if target_key.startswith("mtt::aux::"): + continue weights = self.weights[self.output_to_output_index[target_key]] targets_list = [] sample_values: List[List[int]] = []