Add debug flag and architecture error message #80

Merged · 2 commits · Feb 20, 2024
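In short, this PR adds a global `--debug` flag to the command-line interface (a failing sub-command then prints its full traceback instead of the short red error message) and a new `ArchitectureError` that marks exceptions raised inside an architecture's own code. A minimal way to exercise the flag against an installed package, mirroring the new CLI test further below:

import subprocess

# With --debug present the CLI still behaves normally; this prints the help text
# of the train sub-command and exits with status 0.
subprocess.check_call(["metatensor-models", "--debug", "train", "-h"])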
38 changes: 26 additions & 12 deletions src/metatensor/models/__main__.py
@@ -1,7 +1,8 @@
"""The main entry point for the metatensor-models interface."""
"""The main entry point for the metatensor-models command line interface."""

import argparse
import sys
import traceback

from . import __version__
from .cli import eval_model, export_model, train_model
@@ -16,14 +17,20 @@ def main():
formatter_class=argparse.RawTextHelpFormatter,
)

if len(sys.argv) < 2:
ap.error("You must specify a sub-command")

ap.add_argument(
"--version",
action="version",
version=f"metatensor-models {__version__}",
)

if len(sys.argv) < 2:
ap.error("You must specify a sub-command")
ap.add_argument(
"--debug",
action="store_true",
help="Run with debug options.",
)

# Add sub-parsers
subparser = ap.add_subparsers(help="sub-command help")
@@ -33,15 +40,22 @@ def main():

args = ap.parse_args()
callable = args.__dict__.pop("callable")

if callable == "eval_model":
eval_model(**args.__dict__)
elif callable == "export_model":
export_model(**args.__dict__)
elif callable == "train_model":
train_model(**args.__dict__)
else:
raise ValueError("internal error when selecting a sub-command.")
debug = args.__dict__.pop("debug")

try:
if callable == "eval_model":
eval_model(**args.__dict__)
elif callable == "export_model":
export_model(**args.__dict__)
elif callable == "train_model":
train_model(**args.__dict__)
else:
raise ValueError("internal error when selecting a sub-command.")
except Exception as e:
if debug:
traceback.print_exc()
else:
sys.exit(f"\033[31mERROR: {e}\033[0m") # format error in red!


if __name__ == "__main__":
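The change above routes every sub-command through a single try/except controlled by the new flag; a minimal standalone sketch of the same pattern (the helper name is illustrative):

import sys
import traceback


def run_with_error_handling(func, debug: bool) -> None:
    """Call ``func``; print the full traceback with --debug, a short red error otherwise."""
    try:
        func()
    except Exception as err:
        if debug:
            traceback.print_exc()
        else:
            # \033[31m ... \033[0m wrap the message in ANSI red, as in __main__.py above.
            sys.exit(f"\033[31mERROR: {err}\033[0m")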
46 changes: 30 additions & 16 deletions src/metatensor/models/cli/eval_model.py
@@ -9,6 +9,7 @@

from ..utils.compute_loss import compute_model_loss
from ..utils.data import collate_fn, read_structures, read_targets, write_predictions
from ..utils.errors import ArchitectureError
from ..utils.export import is_exported
from ..utils.extract_targets import get_outputs_dict
from ..utils.info import finalize_aggregated_info, update_aggregated_info
@@ -109,7 +110,13 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset])
finalized_info = finalize_aggregated_info(aggregated_info)

energy_counter = 0
for output in model.capabilities().outputs.values():

try:
outputs_capabilities = model.capabilities().outputs
except Exception as e:
raise ArchitectureError(e)

for output in outputs_capabilities.values():
if output.quantity == "energy":
energy_counter += 1
if energy_counter == 1:
@@ -123,7 +130,7 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset])
if key.endswith("_positions_gradients"):
# check if this is a force
target_name = key[: -len("_positions_gradients")]
if model.capabilities().outputs[target_name].quantity == "energy":
if outputs_capabilities[target_name].quantity == "energy":
# if this is a force, replace the ugly name with "force"
if only_one_energy:
new_key = "force"
@@ -132,9 +139,8 @@ def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset])
elif key.endswith("_displacement_gradients"):
# check if this is a virial/stress
target_name = key[: -len("_displacement_gradients")]
if model.capabilities().outputs[target_name].quantity == "energy":
# if this is a virial/stress,
# replace the ugly name with "virial/stress"
if outputs_capabilities[target_name].quantity == "energy":
# if this is a virial/stress, replace the ugly name with "virial/stress"
if only_one_energy:
new_key = "virial/stress"
else:
@@ -171,26 +177,34 @@ def eval_model(
filename=options["structures"]["read_from"],
fileformat=options["structures"]["file_format"],
)

# Predict targets
if hasattr(options, "targets"):
eval_targets = read_targets(options["targets"])
eval_dataset = Dataset(structure=eval_structures, energy=eval_targets["energy"])
_eval_targets(model, eval_dataset)

# Predict structures
# TODO: batch this
# TODO: add forces/stresses/virials if requested
if not hasattr(options, "targets"):
# otherwise, the NLs will have been computed for the RMSE calculations above
else:
# TODO: batch this
# TODO: add forces/stresses/virials if requested
# Attach neighbors lists to the structures. This step is only required if no targets
# are present. Otherwise, the neighbors lists have already been attached in
# `_eval_targets`.
eval_structures = [
get_system_with_neighbors_lists(
structure, model.requested_neighbors_lists()
)
for structure in eval_structures
]
eval_options = ModelEvaluationOptions(
length_unit="", # this is only needed for unit conversions in MD engines
outputs=model.capabilities().outputs,
)
predictions = model(eval_structures, eval_options, check_consistency=True)

# Predict structures
try:
# `length_unit` is only required for unit conversions in MD engines and
# superfluous here.
eval_options = ModelEvaluationOptions(
length_unit="", outputs=model.capabilities().outputs
)
predictions = model(eval_structures, eval_options, check_consistency=True)
except Exception as e:
raise ArchitectureError(e)

write_predictions(output, predictions, eval_structures)
22 changes: 13 additions & 9 deletions src/metatensor/models/cli/train_model.py
@@ -22,6 +22,7 @@
from .. import CONFIG_PATH
from ..utils.data import get_all_species, read_structures, read_targets
from ..utils.data.dataset import _train_test_random_split
from ..utils.errors import ArchitectureError
from ..utils.export import export
from ..utils.model_io import save_model
from ..utils.omegaconf import check_units, expand_dataset_config
@@ -337,15 +338,18 @@ def _train_model_hydra(options: DictConfig) -> None:
)

logger.info("Calling architecture trainer")
model = architecture.train(
train_datasets=[train_dataset],
validation_datasets=[validation_dataset],
requested_capabilities=requested_capabilities,
hypers=OmegaConf.to_container(options["architecture"]),
continue_from=options["continue_from"],
output_dir=output_dir,
device_str=options["device"],
)
try:
model = architecture.train(
train_datasets=[train_dataset],
validation_datasets=[validation_dataset],
requested_capabilities=requested_capabilities,
hypers=OmegaConf.to_container(options["architecture"]),
continue_from=options["continue_from"],
output_dir=output_dir,
device_str=options["device"],
)
except Exception as e:
raise ArchitectureError(e)

save_model(model, f'{options["output_path"][:-3]}.ckpt')
export(model, options["output_path"])
14 changes: 10 additions & 4 deletions src/metatensor/models/utils/compute_loss.py
@@ -9,6 +9,7 @@
register_autograd_neighbors,
)

from .errors import ArchitectureError
from .export import is_exported
from .loss import TensorMapDictLoss
from .output_gradient import compute_gradient
@@ -41,12 +42,17 @@ def compute_model_loss(

:returns: The loss as a scalar `torch.Tensor`.
"""
try:
device = next(model.parameters()).device
outputs_capabilities = _get_capabilities(model).outputs
except Exception as e:
raise ArchitectureError(e)

# Assert that all targets are within the model's capabilities:
if not set(targets.keys()).issubset(_get_capabilities(model).outputs.keys()):
if not set(targets.keys()).issubset(outputs_capabilities.keys()):
raise ValueError("Not all targets are within the model's capabilities.")

# Infer model device, move systems and targets to the same device:
device = next(model.parameters()).device
# Move systems and targets to the same device:
systems = [system.to(device=device) for system in systems]
targets = {key: target.to(device=device) for key, target in targets.items()}

@@ -56,7 +62,7 @@
energy_targets_that_require_strain_gradients = []
for target_name in targets.keys():
# Check if the target is an energy:
if _get_capabilities(model).outputs[target_name].quantity == "energy":
if outputs_capabilities[target_name].quantity == "energy":
energy_targets.append(target_name)
# Check if the energy requires gradients:
if targets[target_name].block().has_gradient("positions"):
21 changes: 21 additions & 0 deletions src/metatensor/models/utils/errors.py
@@ -0,0 +1,21 @@
class ArchitectureError(Exception):
"""
Exception raised for errors originating from architectures.

This exception should be raised when an error occurs within an architecture's
operation, indicating that the problem is not directly related to the
metatensor-models infrastructure but rather to the specific architecture being used.

:param exception: The original exception that was caught, which led to raising this
custom exception.

The resulting message includes the message of the original exception, followed by a
note emphasizing that the error most likely originates from an architecture.

def __init__(self, exception):
super().__init__(
"The error below most likely originates from an architecture. If you think "
"this is a bug, please contact its maintainer (see the architecture's "
f"documentation).\n\n{exception}"
)
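A minimal sketch of the intended usage pattern, mirroring `train_model.py` above (the helper name `run_architecture_step` is illustrative):

from metatensor.models.utils.errors import ArchitectureError


def run_architecture_step(step, *args, **kwargs):
    # Re-raise any failure coming from the architecture's own code as an
    # ArchitectureError, so that the CLI points users at the architecture's
    # maintainers rather than at the metatensor-models infrastructure.
    try:
        return step(*args, **kwargs)
    except Exception as err:
        raise ArchitectureError(err)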
5 changes: 5 additions & 0 deletions tests/cli/test_cli.py
@@ -25,3 +25,8 @@ def test_available_modules(self, module):
def test_extra_options(self, args):
"""Test extra options."""
subprocess.check_call(["metatensor-models", "--" + args])

@pytest.mark.parametrize("args", ("version", "help"))
def test_debug_flag(self, args):
"""Test that commands still run normally when the debug flag is set."""
subprocess.check_call(["metatensor-models", "--debug", "--" + args])
24 changes: 20 additions & 4 deletions tests/cli/test_train_model.py
@@ -158,8 +158,9 @@ def test_model_consistency_with_seed(
if seed is not None and seed < 0:
with pytest.raises(SystemExit):
train_model(options)
captured = capsys.readouterr()
assert "should be a positive number or None." in captured.out

captured = capsys.readouterr()
assert "should be a positive number or None." in captured.err
return

train_model(options, output="model1.pt")
@@ -191,8 +192,23 @@ def test_error_base_precision(options, monkeypatch, tmp_path, capsys):

with pytest.raises(SystemExit):
train_model(options)
captured = capsys.readouterr()
assert "Only 64, 32 or 16 are possible values for" in captured.out

captured = capsys.readouterr()
assert "Only 64, 32 or 16 are possible values for" in captured.err


def test_architecture_error(options, monkeypatch, tmp_path, capsys):
"""Test that an error is raised if there is a problem with the architecture."""
monkeypatch.chdir(tmp_path)
shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")

options["architecture"]["model"] = OmegaConf.create({"soap": {"cutoff": -1}})

with pytest.raises(SystemExit):
train_model(options)

captured = capsys.readouterr()
assert "likely originates from an architecture" in captured.err


def test_check_architecture_name():
Expand Down
12 changes: 12 additions & 0 deletions tests/utils/test_errors.py
@@ -0,0 +1,12 @@
import pytest

from metatensor.models.utils.errors import ArchitectureError


def test_architecture_error():
match = "The error below most likely originates from an architecture"
with pytest.raises(ArchitectureError, match=match):
try:
raise ValueError("An example error from the architecture")
except Exception as e:
raise ArchitectureError(e)