From 21d60982e0b4d4dedc1883a2416ad84ab366184d Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 11 Nov 2024 13:23:11 +0100 Subject: [PATCH] One more test --- tests/utils/data/test_readers_metatensor.py | 43 +++++++++++++++++++++ tests/utils/data/test_target_info.py | 10 +++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/tests/utils/data/test_readers_metatensor.py b/tests/utils/data/test_readers_metatensor.py index e38ebc4d..b03470b4 100644 --- a/tests/utils/data/test_readers_metatensor.py +++ b/tests/utils/data/test_readers_metatensor.py @@ -29,6 +29,24 @@ def energy_tensor_map(): ) +@pytest.fixture +def scalar_tensor_map(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(2, 10, dtype=torch.float64), + samples=Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[], + properties=Labels.range("properties", 10), + ) + ], + ) + + @pytest.fixture def spherical_tensor_map(): return TensorMap( @@ -129,6 +147,31 @@ def test_read_energy(monkeypatch, tmpdir, energy_tensor_map): assert metatensor.torch.equal(tensor_map, energy_tensor_map) +def test_read_generic_scalar(monkeypatch, tmpdir, scalar_tensor_map): + monkeypatch.chdir(tmpdir) + + torch.save( + [scalar_tensor_map, scalar_tensor_map], + "generic.mts", + ) + + conf = { + "quantity": "generic", + "read_from": "generic.mts", + "reader": "metatensor", + "keys": ["scalar"], + "per_atom": True, + "unit": "unit", + "type": "scalar", + "num_properties": 10, + } + + tensor_maps, target_info = read_generic(OmegaConf.create(conf)) + + for tensor_map in tensor_maps: + assert metatensor.torch.equal(tensor_map, scalar_tensor_map) + + def test_read_generic_spherical(monkeypatch, tmpdir, spherical_tensor_map): monkeypatch.chdir(tmpdir) diff --git a/tests/utils/data/test_target_info.py b/tests/utils/data/test_target_info.py index 8e98cb6c..a07c2722 100644 --- a/tests/utils/data/test_target_info.py +++ b/tests/utils/data/test_target_info.py @@ -59,10 +59,12 @@ def spherical_target_config() -> DictConfig: "per_atom": False, "num_properties": 1, "type": { - "spherical": [ - {"o3_lambda": 0, "o3_sigma": 1}, - {"o3_lambda": 2, "o3_sigma": 1}, - ], + "spherical": { + "irreps": [ + {"o3_lambda": 0, "o3_sigma": 1}, + {"o3_lambda": 2, "o3_sigma": 1}, + ], + }, }, } )