Skip to content

Commit

Permalink
One more test
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 11, 2024
1 parent b815435 commit 21d6098
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
43 changes: 43 additions & 0 deletions tests/utils/data/test_readers_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ def energy_tensor_map():
)


@pytest.fixture
def scalar_tensor_map():
return TensorMap(
keys=Labels.single(),
blocks=[
TensorBlock(
values=torch.rand(2, 10, dtype=torch.float64),
samples=Labels(
names=["system", "atom"],
values=torch.tensor([[0, 0], [0, 1]], dtype=torch.int32),
),
components=[],
properties=Labels.range("properties", 10),
)
],
)


@pytest.fixture
def spherical_tensor_map():
return TensorMap(
Expand Down Expand Up @@ -129,6 +147,31 @@ def test_read_energy(monkeypatch, tmpdir, energy_tensor_map):
assert metatensor.torch.equal(tensor_map, energy_tensor_map)


def test_read_generic_scalar(monkeypatch, tmpdir, scalar_tensor_map):
monkeypatch.chdir(tmpdir)

torch.save(
[scalar_tensor_map, scalar_tensor_map],
"generic.mts",
)

conf = {
"quantity": "generic",
"read_from": "generic.mts",
"reader": "metatensor",
"keys": ["scalar"],
"per_atom": True,
"unit": "unit",
"type": "scalar",
"num_properties": 10,
}

tensor_maps, target_info = read_generic(OmegaConf.create(conf))

for tensor_map in tensor_maps:
assert metatensor.torch.equal(tensor_map, scalar_tensor_map)


def test_read_generic_spherical(monkeypatch, tmpdir, spherical_tensor_map):
monkeypatch.chdir(tmpdir)

Expand Down
10 changes: 6 additions & 4 deletions tests/utils/data/test_target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ def spherical_target_config() -> DictConfig:
"per_atom": False,
"num_properties": 1,
"type": {
"spherical": [
{"o3_lambda": 0, "o3_sigma": 1},
{"o3_lambda": 2, "o3_sigma": 1},
],
"spherical": {
"irreps": [
{"o3_lambda": 0, "o3_sigma": 1},
{"o3_lambda": 2, "o3_sigma": 1},
],
},
},
}
)
Expand Down

0 comments on commit 21d6098

Please sign in to comment.