Skip to content

Commit

Permalink
Fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 11, 2024
1 parent fb4a24d commit 239428d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
10 changes: 7 additions & 3 deletions src/metatrain/utils/data/readers/metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,16 @@ def read_energy(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
"Either all targets should have strain gradients or none."
)

add_position_gradients = all(has_position_gradients)
add_strain_gradients = all(has_strain_gradients)
add_position_gradients = target["forces"]
add_strain_gradients = target["stress"] or target["virial"]
print(add_position_gradients, add_strain_gradients)
target_info = get_energy_target_info(
target, add_position_gradients, add_strain_gradients
)

print(target_info.layout.block())
print(tensor_maps[0].block())

# now check all the expected metadata (from target_info.layout) matches
# the actual metadata in the tensor maps
_check_tensor_maps_metadata(tensor_maps, target_info.layout)
Expand Down Expand Up @@ -92,7 +96,7 @@ def _check_tensor_maps_metadata(tensor_maps: List[TensorMap], layout: TensorMap)
)
for key in layout.keys:
block = tensor_map.block(key)
block_from_layout = tensor_map.block(key)
block_from_layout = layout.block(key)
if block.samples.names != block_from_layout.samples.names:
raise ValueError(

Check warning on line 101 in src/metatrain/utils/data/readers/metatensor.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/readers/metatensor.py#L101

Added line #L101 was not covered by tests
f"Unexpected samples in metatensor targets at index {i}: "
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/utils/data/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _get_cartesian_target_info(target: DictConfig) -> TargetInfo:
components = [Labels(["xyz"], torch.arange(3).reshape(-1, 1))]
else:
components = []
for component in range(target["type"][cartesian_key]["rank"]):
for component in range(1, target["type"][cartesian_key]["rank"] + 1):
components.append(
Labels(
names=[f"xyz_{component}"],
Expand Down Expand Up @@ -357,7 +357,7 @@ def _get_spherical_target_info(target: DictConfig) -> TargetInfo:
Labels(
names=["o3_mu"],
values=torch.arange(
2 * irrep["o3_lambda"] + 1, dtype=torch.int32
-irrep["o3_lambda"], irrep["o3_lambda"] + 1, dtype=torch.int32
).reshape(-1, 1),
)
]
Expand Down
19 changes: 14 additions & 5 deletions tests/utils/data/test_readers_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def spherical_tensor_map():
values=torch.arange(0, 1, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.single(),
properties=Labels.range("properties", 1),
),
TensorBlock(
values=torch.rand(1, 5, 1, dtype=torch.float64),
Expand All @@ -82,7 +82,7 @@ def spherical_tensor_map():
values=torch.arange(-2, 3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.single(),
properties=Labels.range("properties", 1),
),
],
)
Expand All @@ -109,7 +109,7 @@ def cartesian_tensor_map():
values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1),
),
],
properties=Labels.single(),
properties=Labels.range("properties", 1),
),
],
)
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_read_generic_cartesian(monkeypatch, tmpdir, cartesian_tensor_map):
assert metatensor.torch.equal(tensor_map, cartesian_tensor_map)


def test_read_error(monkeypatch, tmpdir):
def test_read_errors(monkeypatch, tmpdir, energy_tensor_map):
monkeypatch.chdir(tmpdir)

numpy_array = np.zeros((2, 2))
Expand All @@ -249,10 +249,19 @@ def test_read_error(monkeypatch, tmpdir):
"type": "scalar",
"per_atom": False,
"num_properties": 1,
"forces": False,
"forces": True,
"stress": False,
"virial": False,
}

with pytest.raises(ValueError, match="Failed to read"):
read_energy(OmegaConf.create(conf))

torch.save(
[energy_tensor_map, energy_tensor_map],
"energy.mts",
)
conf["read_from"] = "energy.mts"

with pytest.raises(ValueError, match="Unexpected gradients"):
read_energy(OmegaConf.create(conf))

0 comments on commit 239428d

Please sign in to comment.