Improve error and logging messages #341

Merged · 7 commits · Sep 26, 2024
28 changes: 23 additions & 5 deletions src/metatrain/utils/data/dataset.py
@@ -246,7 +246,7 @@
"""Returns the statistics of a dataset or subset as a string."""

dataset_len = len(dataset)
stats = f"Dataset of size {dataset_len}"
stats = f"Dataset containing {dataset_len} structures"
if dataset_len == 0:
return stats

@@ -389,17 +389,35 @@
or targets that are not present in the training set
"""
# Check that system `dtypes` are consistent within datasets
desired_dtype = train_datasets[0][0].system.positions.dtype
msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and "
desired_dtype = None
for train_dataset in train_datasets:
if len(train_dataset) == 0:
continue

(Codecov warning: added line src/metatrain/utils/data/dataset.py#L395 not covered by tests)

actual_dtype = train_dataset[0].system.positions.dtype
if desired_dtype is None:
desired_dtype = actual_dtype

if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`")
raise TypeError(
"`dtype` between datasets is inconsistent, "
f"found {desired_dtype} and {actual_dtype} in training datasets"
)

for val_dataset in val_datasets:
if len(val_dataset) == 0:
continue

(Codecov warning: added line src/metatrain/utils/data/dataset.py#L409 not covered by tests)

actual_dtype = val_dataset[0].system.positions.dtype

if desired_dtype is None:
desired_dtype = actual_dtype

(Codecov warning: added line src/metatrain/utils/data/dataset.py#L414 not covered by tests)

if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`")
raise TypeError(
"`dtype` between datasets is inconsistent, "
f"found {desired_dtype} and {actual_dtype} in validation datasets"
)

# Get all targets in the training and validation sets:
train_targets = get_all_targets(train_datasets)
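To make the new behavior concrete, here is a minimal standalone sketch of the revised check, using plain lists of torch tensors in place of metatrain datasets (the helper name is hypothetical): the first non-empty dataset sets the reference dtype, empty datasets are skipped, and any mismatch raises with the clearer message.

```python
import torch

def check_dtype_consistency(datasets, label):
    """Sketch: the first non-empty dataset sets the reference dtype."""
    desired_dtype = None
    for dataset in datasets:
        if len(dataset) == 0:
            continue  # empty datasets are skipped, not flagged
        actual_dtype = dataset[0].dtype
        if desired_dtype is None:
            desired_dtype = actual_dtype
        if actual_dtype != desired_dtype:
            raise TypeError(
                "`dtype` between datasets is inconsistent, "
                f"found {desired_dtype} and {actual_dtype} in {label} datasets"
            )
    return desired_dtype

# Mixing float64 and float32 now fails with the clearer message:
check_dtype_consistency(
    [[torch.zeros(3, dtype=torch.float64)], [torch.zeros(3, dtype=torch.float32)]],
    "training",
)
```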
15 changes: 11 additions & 4 deletions src/metatrain/utils/data/readers/ase.py
@@ -7,13 +7,20 @@
from metatensor.torch.atomistic import System, systems_to_torch


def _wrapped_ase_io_read(filename):
try:
return ase.io.read(filename, ":")
except Exception as e:
raise ValueError(f"Failed to read '{filename}' with ASE: {e}") from e

(Codecov warning: added lines src/metatrain/utils/data/readers/ase.py#L13-L14 not covered by tests)


def read_systems_ase(filename: str) -> List[System]:
"""Store system informations using ase.

:param filename: name of the file to read
:returns: A list of systems
"""
return systems_to_torch(ase.io.read(filename, ":"), dtype=torch.float64)
return systems_to_torch(_wrapped_ase_io_read(filename), dtype=torch.float64)


def read_energy_ase(filename: str, key: str) -> List[TensorBlock]:
@@ -23,7 +30,7 @@
:param key: target value key name to be parsed from the file.
:returns: TensorMap containing the energies
"""
frames = ase.io.read(filename, ":")
frames = _wrapped_ase_io_read(filename)

properties = Labels("energy", torch.tensor([[0]]))

@@ -57,7 +64,7 @@
:param key: target value key name to be parsed from the file.
:returns: TensorMap containing the forces
"""
frames = ase.io.read(filename, ":")
frames = _wrapped_ase_io_read(filename)

components = [Labels(["xyz"], torch.arange(3).reshape(-1, 1))]
properties = Labels("energy", torch.tensor([[0]]))
@@ -117,7 +124,7 @@
def _read_virial_stress_ase(
filename: str, key: str, is_virial: bool = True
) -> List[TensorBlock]:
frames = ase.io.read(filename, ":")
frames = _wrapped_ase_io_read(filename)

samples = Labels(["sample"], torch.tensor([[0]]))
components = [
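With the wrapper in place, any ASE read failure, whether from a missing file or a malformed one, surfaces as a single uniform ValueError instead of a library-specific traceback. A quick sketch of the expected behavior; the file name here is hypothetical:

```python
from metatrain.utils.data.readers.ase import read_systems_ase

try:
    systems = read_systems_ase("missing_or_corrupt.xyz")  # hypothetical path
except ValueError as err:
    # Reads "Failed to read 'missing_or_corrupt.xyz' with ASE: ...";
    # the original ASE exception stays reachable via err.__cause__.
    print(err)
```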
57 changes: 27 additions & 30 deletions src/metatrain/utils/data/readers/readers.py
@@ -31,11 +31,11 @@
) -> List[Any]:
if reader is None:
try:
filesuffix = Path(filename).suffix
reader = DEFAULT_READER[filesuffix]
file_suffix = Path(filename).suffix
reader = DEFAULT_READER[file_suffix]
except KeyError:
raise ValueError(
f"File extension {filesuffix!r} is not linked to a default reader "
f"File extension {file_suffix!r} is not linked to a default reader "
"library. You can try reading it by setting a specific 'reader' from "
f"the known ones: {', '.join(AVAILABLE_READERS)} "
)
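The dispatch itself is a plain file-suffix lookup. Below is a minimal sketch of the pattern; the mapping contents and the helper name are illustrative assumptions, not the real metatrain definitions:

```python
from pathlib import Path

DEFAULT_READER = {".xyz": "ase", ".extxyz": "ase"}  # hypothetical contents
AVAILABLE_READERS = ["ase"]

def pick_reader(filename: str) -> str:
    file_suffix = Path(filename).suffix
    try:
        return DEFAULT_READER[file_suffix]
    except KeyError:
        raise ValueError(
            f"File extension {file_suffix!r} is not linked to a default reader "
            "library. You can try reading it by setting a specific 'reader' from "
            f"the known ones: {', '.join(AVAILABLE_READERS)}"
        )
```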
@@ -171,7 +171,7 @@
This function uses subfunctions like :func:`read_energy` to parse the requested
target quantity. Currently only `energy` is a supported target property. But, within
the `energy` section gradients such as `forces`, the `stress` or the `virial` can be
added. Other gradients are silentlty irgnored.
added. Other gradients are silently ignored.

:param conf: config containing the keys for what should be read.
:returns: Dictionary containing a list of TensorMaps for each target section in the
@@ -191,13 +191,19 @@
for target_key, target in conf.items():
target_info_gradients: List[str] = []

if target_key not in standard_outputs_list and not target_key.startswith(
"mtt::"
):
raise ValueError(
f"Target names must either be one of {standard_outputs_list} "
"or start with `mtt::`."
)
is_standard_target = target_key in standard_outputs_list
if not is_standard_target and not target_key.startswith("mtt::"):
if target_key.lower() in ["force", "forces", "virial", "stress"]:
raise ValueError(

(Codecov warning: added line src/metatrain/utils/data/readers/readers.py#L197 not covered by tests)
f"{target_key!r} should not be it's own top-level target, "
"but rather a sub-section of the 'energy' target"
)
else:
raise ValueError(
f"Target name ({target_key}) must either be one of "
f"{standard_outputs_list} or start with `mtt::`."
)

if target["quantity"] == "energy":
blocks = read_energy(
filename=target["read_from"],
@@ -213,14 +219,11 @@
reader=target["forces"]["reader"],
)
except Exception:
logger.warning(
f"No Forces found in section {target_key!r}. "
"Continue without forces!"
)
logger.warning(f"No forces found in section {target_key!r}.")
else:
logger.info(
f"Forces found in section {target_key!r}. Forces are taken for "
"training!"
f"Forces found in section {target_key!r}, "
"we will use this gradient to train the model"
)
for block, position_gradient in zip(blocks, position_gradients):
block.add_gradient(
@@ -230,7 +233,7 @@
target_info_gradients.append("positions")

if target["stress"] and target["virial"]:
raise ValueError("Cannot use stress and virial at the same time!")
raise ValueError("Cannot use stress and virial at the same time")

if target["stress"]:
try:
@@ -240,14 +243,11 @@
reader=target["stress"]["reader"],
)
except Exception:
logger.warning(
f"No Stress found in section {target_key!r}. "
"Continue without stress!"
)
logger.warning(f"No stress found in section {target_key!r}.")
else:
logger.info(
f"Stress found in section {target_key!r}. Stress is taken for "
f"training!"
f"Stress found in section {target_key!r}, "
"we will use this gradient to train the model"
)
for block, strain_gradient in zip(blocks, strain_gradients):
block.add_gradient(parameter="strain", gradient=strain_gradient)
@@ -262,14 +262,11 @@
reader=target["virial"]["reader"],
)
except Exception:
logger.warning(
f"No Virial found in section {target_key!r}. "
"Continue without virial!"
)
logger.warning(f"No virial found in section {target_key!r}.")
else:
logger.info(
f"Virial found in section {target_key!r}. Virial is taken for "
f"training!"
f"Virial found in section {target_key!r}, "
"we will use this gradient to train the model"
)
for block, strain_gradient in zip(blocks, strain_gradients):
block.add_gradient(parameter="strain", gradient=strain_gradient)
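Taken together, these readers.py changes tighten target-name validation: 'force(s)', 'stress', and 'virial' as top-level keys now get a dedicated hint pointing to the 'energy' sub-sections, while anything else must be a standard output or carry the `mtt::` prefix. A standalone sketch of that rule (the standard_outputs_list default here is an assumption):

```python
def validate_target_key(target_key, standard_outputs_list=("energy",)):
    """Sketch of the validation logic added to read_targets."""
    if target_key in standard_outputs_list or target_key.startswith("mtt::"):
        return  # accepted
    if target_key.lower() in ["force", "forces", "virial", "stress"]:
        raise ValueError(
            f"{target_key!r} should not be its own top-level target, "
            "but rather a sub-section of the 'energy' target"
        )
    raise ValueError(
        f"Target name ({target_key}) must either be one of "
        f"{list(standard_outputs_list)} or start with `mtt::`."
    )

validate_target_key("mtt::dipole")  # accepted
validate_target_key("forces")       # raises with the new, more specific hint
```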
2 changes: 1 addition & 1 deletion tests/cli/test_train_model.py
@@ -77,7 +77,7 @@ def test_train(capfd, monkeypatch, tmp_path, output):
assert "Training dataset:" in stdout_log
assert "Validation dataset:" in stdout_log
assert "Test dataset:" in stdout_log
assert "size 50" in stdout_log
assert "50 structures" in stdout_log
assert "mean " in stdout_log
assert "std " in stdout_log
assert "[INFO]" in stdout_log
11 changes: 6 additions & 5 deletions tests/utils/data/test_dataset.py
@@ -557,25 +557,26 @@ def test_check_datasets():
# wrong dtype
systems_qm9_32bit = [system.to(dtype=torch.float32) for system in systems_qm9]
targets_qm9_32bit = {
k: [v.to(dtype=torch.float32) for v in l] for k, l in targets_qm9.items()
name: [tensor.to(dtype=torch.float32) for tensor in values]
for name, values in targets_qm9.items()
}
train_set_32_bit = Dataset.from_dict(
{"system": systems_qm9_32bit, **targets_qm9_32bit}
)

match = (
"`dtype` between datasets is inconsistent, found torch.float64 and "
"torch.float32 found in `val_datasets`"
"torch.float32 in validation datasets"
)
with pytest.raises(TypeError, match=match):
check_datasets([train_set], [train_set_32_bit])

match = (
"`dtype` between datasets is inconsistent, found torch.float64 and "
"torch.float32 found in `train_datasets`"
"torch.float32 in training datasets"
)
with pytest.raises(TypeError, match=match):
check_datasets([train_set, train_set_32_bit], [val_set])
check_datasets([train_set, train_set_32_bit], [])


def test_collate_fn():
@@ -651,7 +652,7 @@ def test_get_stats():
stats = get_stats(dataset, dataset_info)
stats_2 = get_stats(dataset_2, dataset_info)

assert "size 100" in stats
assert "100 structures" in stats
assert "mtt::U0" in stats
assert "energy" in stats_2
assert "mean " in stats
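Since pytest.raises treats match as a regular expression searched against the exception text, these match strings must track the raise sites in dataset.py exactly. A self-contained illustration of the pattern:

```python
import pytest

def inconsistent():
    raise TypeError(
        "`dtype` between datasets is inconsistent, "
        "found torch.float64 and torch.float32 in validation datasets"
    )

def test_inconsistent_dtype_message():
    match = (
        "`dtype` between datasets is inconsistent, found torch.float64 and "
        "torch.float32 in validation datasets"
    )
    with pytest.raises(TypeError, match=match):
        inconsistent()
```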
6 changes: 3 additions & 3 deletions tests/utils/data/test_readers.py
@@ -256,12 +256,12 @@ def test_read_targets_warnings(stress_dict, virial_dict, monkeypatch, tmp_path,
caplog.set_level(logging.WARNING)
read_targets(OmegaConf.create(conf)) # , slice_samples_by="system")

assert any(["No Forces found" in rec.message for rec in caplog.records])
assert any(["No forces found" in rec.message for rec in caplog.records])

if stress_dict:
assert any(["No Stress found" in rec.message for rec in caplog.records])
assert any(["No stress found" in rec.message for rec in caplog.records])
if virial_dict:
assert any(["No Virial found" in rec.message for rec in caplog.records])
assert any(["No virial found" in rec.message for rec in caplog.records])


def test_read_targets_error(monkeypatch, tmp_path):
6 changes: 3 additions & 3 deletions tox.ini
@@ -78,10 +78,10 @@ deps =
build
check-manifest
twine
allowlist_externals = bash
allowlist_externals = rm
commands_pre =
bash -c "if [ -e {toxinidir}/dist/*tar.gz ]; then unlink {toxinidir}/dist/*.whl; fi"
bash -c "if [ -e {toxinidir}/dist/*tar.gz ]; then unlink {toxinidir}/dist/*.tar.gz; fi"
rm -f {toxinidir}/dist/*.whl
rm -f {toxinidir}/dist/*.tar.gz
commands =
python -m build
twine check dist/*.tar.gz dist/*.whl