diff --git a/src/metatrain/__main__.py b/src/metatrain/__main__.py
index d122c0a08..23b6c173f 100644
--- a/src/metatrain/__main__.py
+++ b/src/metatrain/__main__.py
@@ -1,7 +1,6 @@
 """The main entry point for the metatrain command line interface."""
 
 import argparse
-import importlib
 import logging
 import os
 import sys
@@ -9,14 +8,14 @@
 from datetime import datetime
 from pathlib import Path
 
-import metatensor.torch
-from omegaconf import OmegaConf
-
 from . import __version__
-from .cli.eval import _add_eval_model_parser, eval_model
-from .cli.export import _add_export_model_parser, export_model
-from .cli.train import _add_train_model_parser, train_model
-from .utils.architectures import check_architecture_name
+from .cli.eval import _add_eval_model_parser, _prepare_eval_model_args, eval_model
+from .cli.export import (
+    _add_export_model_parser,
+    _prepare_export_model_args,
+    export_model,
+)
+from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
 from .utils.logging import setup_logging
@@ -71,59 +70,48 @@ def main():
     args = ap.parse_args()
     callable = args.__dict__.pop("callable")
     debug = args.__dict__.pop("debug")
-    logfile = None
+    log_file = None
+    error_file = Path("error.log")
 
     if debug:
         level = logging.DEBUG
     else:
         level = logging.INFO
 
-    if callable == "eval_model":
-        args.__dict__["model"] = metatensor.torch.atomistic.load_atomistic_model(
-            path=args.__dict__.pop("path"),
-            extensions_directory=args.__dict__.pop("extensions_directory"),
-        )
-    elif callable == "export_model":
-        architecture_name = args.__dict__.pop("architecture_name")
-        check_architecture_name(architecture_name)
-        architecture = importlib.import_module(f"metatrain.{architecture_name}")
-
-        args.__dict__["model"] = architecture.__model__.load_checkpoint(
-            args.__dict__.pop("path")
-        )
-    elif callable == "train_model":
-        # define and create `checkpoint_dir` based on current directory and date/time
+    if callable == "train_model":
+        # define and create `checkpoint_dir` based on current directory, date and time
        checkpoint_dir = _datetime_output_path(now=datetime.now())
         os.makedirs(checkpoint_dir)
-        args.__dict__["checkpoint_dir"] = checkpoint_dir
-
-        # save log to file
-        logfile = checkpoint_dir / "train.log"
+        args.checkpoint_dir = checkpoint_dir
 
-        # merge/override file options with command line options
-        override_options = args.__dict__.pop("override_options")
-        if override_options is None:
-            override_options = {}
-
-        args.options = OmegaConf.merge(args.options, override_options)
-    else:
-        raise ValueError("internal error when selecting a sub-command.")
+        log_file = checkpoint_dir / "train.log"
+        error_file = checkpoint_dir / error_file
 
-    with setup_logging(logger, logfile=logfile, level=level):
+    with setup_logging(logger, log_file=log_file, level=level):
         try:
             if callable == "eval_model":
+                _prepare_eval_model_args(args)
                 eval_model(**args.__dict__)
             elif callable == "export_model":
+                _prepare_export_model_args(args)
                 export_model(**args.__dict__)
             elif callable == "train_model":
+                _prepare_train_model_args(args)
                 train_model(**args.__dict__)
             else:
-                raise ValueError("internal error when selecting a sub-command.")
+                raise ValueError("internal error when selecting a sub-command")
         except Exception as err:
-            if debug:
-                traceback.print_exc()
-            else:
-                sys.exit(str(err))
+            logging.error(
+                "If the error message below is unclear, please help us improve it by "
+                "opening an issue at https://github.com/lab-cosmo/metatrain/issues. "
" + "When opening the issue, please include the full traceback log from " + f"{str(error_file.absolute().resolve())!r}. Thank you!\n\n{err}" + ) + + with open(error_file, "w") as f: + f.write(traceback.format_exc()) + + sys.exit(1) if __name__ == "__main__": diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 57e65a911..0be2ff1d4 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -55,7 +55,7 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None: ) parser.add_argument( "options", - type=OmegaConf.load, + type=str, help="Eval options file to define a dataset for evaluation.", ) parser.add_argument( @@ -81,6 +81,15 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None: ) +def _prepare_eval_model_args(args: argparse.Namespace) -> None: + """Prepare arguments for eval_model.""" + args.options = OmegaConf.load(args.options) + args.model = metatensor.torch.atomistic.load_atomistic_model( + path=args.__dict__.pop("path"), + extensions_directory=args.__dict__.pop("extensions_directory"), + ) + + def _concatenate_tensormaps( tensormap_dict_list: List[Dict[str, TensorMap]] ) -> Dict[str, TensorMap]: diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index 642ac0b36..e84152bf8 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -1,11 +1,12 @@ import argparse +import importlib import logging from pathlib import Path from typing import Any, Union import torch -from ..utils.architectures import find_all_architectures +from ..utils.architectures import check_architecture_name, find_all_architectures from ..utils.export import is_exported from ..utils.io import check_suffix from .formatter import CustomHelpFormatter @@ -53,6 +54,15 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: ) +def _prepare_export_model_args(args: argparse.Namespace) -> None: + """Prepare arguments for export_model.""" + architecture_name = args.__dict__.pop("architecture_name") + check_architecture_name(architecture_name) + architecture = importlib.import_module(f"metatrain.{architecture_name}") + + args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path")) + + def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None: """Export a trained model to allow it to make predictions. 
diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 77d10a731..28e57d9c6 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -56,7 +56,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
     parser.add_argument(
         "options",
-        type=OmegaConf.load,
+        type=str,
         help="Options file",
     )
     parser.add_argument(
@@ -85,6 +85,17 @@
     )
 
 
+def _prepare_train_model_args(args: argparse.Namespace) -> None:
+    """Prepare arguments for train_model."""
+    args.options = OmegaConf.load(args.options)
+    # merge/override file options with command line options
+    override_options = args.__dict__.pop("override_options")
+    if override_options is None:
+        override_options = {}
+
+    args.options = OmegaConf.merge(args.options, override_options)
+
+
 def train_model(
     options: Union[DictConfig, Dict],
     output: str = "model.pt",
diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py
index 5950fafc0..3ea309ebd 100644
--- a/src/metatrain/utils/data/readers/readers.py
+++ b/src/metatrain/utils/data/readers/readers.py
@@ -204,7 +204,7 @@ def read_targets(
                     fileformat=target["forces"]["file_format"],
                     dtype=dtype,
                 )
-            except KeyError:
+            except Exception:
                 logger.warning(
                     f"No Forces found in section {target_key!r}. "
                     "Continue without forces!"
                 )
@@ -232,7 +232,7 @@
                     fileformat=target["stress"]["file_format"],
                     dtype=dtype,
                 )
-            except KeyError:
+            except Exception:
                 logger.warning(
                     f"No Stress found in section {target_key!r}. "
                     "Continue without stress!"
                 )
@@ -255,7 +255,7 @@
                     fileformat=target["virial"]["file_format"],
                     dtype=dtype,
                 )
-            except KeyError:
+            except Exception:
                 logger.warning(
                     f"No Virial found in section {target_key!r}. "
                     "Continue without virial!"
                 )
diff --git a/src/metatrain/utils/data/readers/targets/ase.py b/src/metatrain/utils/data/readers/targets/ase.py
index ff24db756..fe4e10117 100644
--- a/src/metatrain/utils/data/readers/targets/ase.py
+++ b/src/metatrain/utils/data/readers/targets/ase.py
@@ -26,6 +26,12 @@ def read_energy_ase(
 
     blocks = []
     for i_system, atoms in enumerate(frames):
+        if key not in atoms.info:
+            raise ValueError(
+                f"energy key {key!r} was not found in system {filename!r} at index "
+                f"{i_system}"
+            )
+
         values = torch.tensor([[atoms.info[key]]], dtype=dtype)
 
         samples = Labels(["system"], torch.tensor([[i_system]]))
@@ -62,6 +68,13 @@ def read_forces_ase(
 
     blocks = []
     for i_system, atoms in enumerate(frames):
+
+        if key not in atoms.arrays:
+            raise ValueError(
+                f"forces key {key!r} was not found in system {filename!r} at index "
+                f"{i_system}"
+            )
+
         # We store forces as positions gradients which means we invert the sign
         values = -torch.tensor(atoms.arrays[key], dtype=dtype)
         values = values.reshape(-1, 3, 1)
@@ -137,8 +150,7 @@ def _read_virial_stress_ase(
     :param is_virial: if target values are stored as stress or virials.
     :param dtype: desired data type of returned tensor
 
-    :returns:
-        TensorMap containing the given information
+    :returns: TensorMap containing the given information
     """
     frames = ase.io.read(filename, ":")
@@ -151,6 +163,16 @@
 
     blocks = []
     for i_system, atoms in enumerate(frames):
+        if key not in atoms.info:
+            if is_virial:
+                target_name = "virial"
+            else:
+                target_name = "stress"
+
+            raise ValueError(
+                f"{target_name} key {key!r} was not found in system {filename!r} at "
+                f"index {i_system}"
+            )
 
         values = torch.tensor(atoms.info[key].tolist(), dtype=dtype)
diff --git a/src/metatrain/utils/errors.py b/src/metatrain/utils/errors.py
index 82c9ec3be..c7abf2938 100644
--- a/src/metatrain/utils/errors.py
+++ b/src/metatrain/utils/errors.py
@@ -12,7 +12,8 @@ class ArchitectureError(Exception):
 
     def __init__(self, exception):
         super().__init__(
-            "The error below most likely originates from an architecture. If you think "
-            "this is a bug, please contact its maintainer (see the architecture's "
-            f"documentation).\n\n{exception}"
+            f"{exception}\n\nThe error above most likely originates from an "
+            "architecture.\n\nIf you think this is a bug, please contact its "
+            "maintainer (see the architecture's documentation) and include the full "
+            "traceback from error.log."
         )
diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py
index 05a5e5e71..3cb56a218 100644
--- a/src/metatrain/utils/logging.py
+++ b/src/metatrain/utils/logging.py
@@ -144,8 +144,8 @@ def _get_digits(value: float) -> Tuple[int, int]:
 
 @contextlib.contextmanager
 def setup_logging(
-    logobj: logging.Logger,
-    logfile: Optional[Union[str, Path]] = None,
+    log_obj: logging.Logger,
+    log_file: Optional[Union[str, Path]] = None,
     level: int = logging.WARNING,
 ):
     """Create a logging environment for a given ``log_obj``.
@@ -153,8 +153,8 @@
     Extracted and adjusted from
     github.com/MDAnalysis/mdacli/blob/main/src/mdacli/logger.py
 
-    :param logobj: A logging instance
-    :param logfile: Name of the log file
+    :param log_obj: A logging instance
+    :param log_file: Name of the log file
     :param level: Set the root logger level to the specified level. If for example
         set to :py:obj:`logging.DEBUG` detailed debug logs including filename and
         function name are displayed.
         For :py:obj:`logging.INFO` only the message logged from
@@ -173,21 +173,22 @@
         stream_handler.setFormatter(formatter)
         handlers.append(stream_handler)
 
-    if logfile:
-        logfile = check_suffix(filename=logfile, suffix=".log")
-        file_handler = logging.FileHandler(filename=str(logfile), encoding="utf-8")
+    if log_file:
+        log_file = check_suffix(filename=log_file, suffix=".log")
+        file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
         file_handler.setFormatter(formatter)
         handlers.append(file_handler)
 
     logging.basicConfig(format=format, handlers=handlers, level=level, style="{")
 
-    if logfile:
-        logobj.info(f"This log is also available in {str(logfile)!r}.")
+    if log_file:
+        abs_path = str(Path(log_file).absolute().resolve())
+        log_obj.info(f"This log is also available at {abs_path!r}.")
     else:
-        logobj.info("Logging to file is disabled.")
+        log_obj.info("Logging to file is disabled.")
 
     for handler in handlers:
-        logobj.addHandler(handler)
+        log_obj.addHandler(handler)
 
     yield
@@ -195,4 +196,4 @@
     for handler in handlers:
         handler.flush()
         handler.close()
-        logobj.removeHandler(handler)
+        log_obj.removeHandler(handler)
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index 083a67abe..5a5feb550 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -1,8 +1,10 @@
 """Tests for argument parsing."""
 
+import glob
 import shutil
 import subprocess
 from pathlib import Path
+from subprocess import CalledProcessError
 from typing import List
 
 import pytest
@@ -103,3 +105,35 @@
 def test_subcommand_completion(partial_word, expected_completion):
     """Test that expected subcommand completion matches."""
     assert set(get_completion_suggestions(partial_word)) == set(expected_completion)
+
+
+@pytest.mark.parametrize("subcommand", ["train", "eval"])
+def test_error(subcommand, capfd, monkeypatch, tmp_path):
+    """Test expected display of errors to stdout and log files."""
+    monkeypatch.chdir(tmp_path)
+
+    command = ["mtt", subcommand]
+    if subcommand == "eval":
+        command += ["model.pt"]
+
+    command += ["foo.yaml"]
+
+    with pytest.raises(CalledProcessError):
+        subprocess.check_call(command)
+
+    stdout_log = capfd.readouterr().out
+
+    if subcommand == "train":
+        error_glob = glob.glob("outputs/*/*/error.log")
+        error_file = error_glob[0]
+    else:
+        error_file = "error.log"
+
+    error_file = str(Path(error_file).absolute().resolve())
+
+    with open(error_file) as f:
+        error_log = f.read()
+
+    assert f"please include the full traceback log from {error_file!r}" in stdout_log
+    assert "No such file or directory" in stdout_log
+    assert "Traceback" in error_log
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index abd3cc579..f255f358e 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -69,19 +69,18 @@ def test_train(capfd, monkeypatch, tmp_path, output):
 
     assert file_log == stdout_log
 
-    for logtext in [stdout_log, file_log]:
-        assert "This log is also available"
-        assert re.search(r"Random seed of this run is [1-9]\d*", logtext)
-        assert "Training dataset has size"
-        assert "Validation dataset has size"
-        assert "Test dataset has size"
-        assert "[INFO]" in logtext
-        assert "Epoch" in logtext
-        assert "loss" in logtext
-        assert "validation" in logtext
-        assert "train" in logtext
-        assert "energy" in logtext
-        assert "with index" not in logtext  # index only printed for more than 1 dataset
+    assert "This log is also available" in stdout_log
+    assert re.search(r"Random seed of this run is [1-9]\d*", stdout_log)
+    assert "Training dataset has size" in stdout_log
+    assert "Validation dataset has size" in stdout_log
+    assert "Test dataset has size" in stdout_log
+    assert "[INFO]" in stdout_log
+    assert "Epoch" in stdout_log
+    assert "loss" in stdout_log
+    assert "validation" in stdout_log
+    assert "train" in stdout_log
+    assert "energy" in stdout_log
+    assert "with index" not in stdout_log  # index only printed for more than 1 dataset
 
 
 @pytest.mark.parametrize(
diff --git a/tests/utils/data/targets/test_targets_ase.py b/tests/utils/data/targets/test_targets_ase.py
index ba2fcf848..29cee43b5 100644
--- a/tests/utils/data/targets/test_targets_ase.py
+++ b/tests/utils/data/targets/test_targets_ase.py
@@ -49,6 +49,29 @@ def test_read_energy_ase(monkeypatch, tmp_path):
     torch.testing.assert_close(result.values, expected)
 
 
+@pytest.mark.parametrize(
+    "func, target_name",
+    [
+        (read_energy_ase, "energy"),
+        (read_forces_ase, "forces"),
+        (read_virial_ase, "virial"),
+        (read_stress_ase, "stress"),
+    ],
+)
+def test_ase_key_errors(func, target_name, monkeypatch, tmp_path):
+    monkeypatch.chdir(tmp_path)
+
+    filename = "systems.xyz"
+
+    systems = ase_systems()
+    ase.io.write(filename, systems)
+
+    match = f"{target_name} key 'foo' was not found in system {filename!r} at index 0"
+
+    with pytest.raises(ValueError, match=match):
+        func(filename=filename, key="foo")
+
+
 def test_read_forces_ase(monkeypatch, tmp_path):
     monkeypatch.chdir(tmp_path)
diff --git a/tests/utils/test_errors.py b/tests/utils/test_errors.py
index 56156dea6..ac02abdf3 100644
--- a/tests/utils/test_errors.py
+++ b/tests/utils/test_errors.py
@@ -4,7 +4,7 @@
 
 
 def test_architecture_error():
-    match = "The error below most likely originates from an architecture"
+    match = "The error above most likely originates from an architecture"
     with pytest.raises(ArchitectureError, match=match):
         try:
             raise ValueError("An example error from the architecture")
diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py
index a13f4797a..8bb7a173c 100644
--- a/tests/utils/test_logging.py
+++ b/tests/utils/test_logging.py
@@ -40,7 +40,7 @@ def test_info_log(caplog, monkeypatch, tmp_path, capsys):
     caplog.set_level(logging.INFO)
     logger = logging.getLogger("test")
 
-    with setup_logging(logger, logfile="logfile.log", level=logging.INFO):
+    with setup_logging(logger, log_file="logfile.log", level=logging.INFO):
         logger.info("foo")
         logger.debug("A debug message")
 
@@ -48,10 +48,10 @@ def test_info_log(caplog, monkeypatch, tmp_path, capsys):
         file_log = f.read()
 
     stdout_log = capsys.readouterr().out
-
-    assert "This log is also available in 'logfile.log'" in caplog.text
+    log_path = str((tmp_path / "logfile.log").absolute())
 
     assert file_log == stdout_log
+    assert f"This log is also available at '{log_path}'" in caplog.text
 
     for logtext in [stdout_log, file_log]:
         assert_log_entry(logtext, loglevel="INFO", message="foo")
@@ -64,7 +64,7 @@ def test_debug_log(caplog, monkeypatch, tmp_path, capsys):
     caplog.set_level(logging.DEBUG)
     logger = logging.getLogger("test")
 
-    with setup_logging(logger, logfile="logfile.log", level=logging.DEBUG):
+    with setup_logging(logger, log_file="logfile.log", level=logging.DEBUG):
         logger.info("foo")
         logger.debug("A debug message")
 
@@ -72,9 +72,10 @@ def test_debug_log(caplog, monkeypatch, tmp_path, capsys):
         file_log = f.read()
 
     stdout_log = capsys.readouterr().out
+    log_path = str((tmp_path / "logfile.log").absolute())
 
     assert file_log == stdout_log
-    assert "This log is also available in 'logfile.log'" in caplog.text
"This log is also available in 'logfile.log'" in caplog.text + assert f"This log is also available at '{log_path}'" in caplog.text for logtext in [stdout_log, file_log]: assert "foo" in logtext @@ -103,7 +104,7 @@ def test_metric_logger(caplog, capsys): } ] - with setup_logging(logger, logfile="logfile.log", level=logging.INFO): + with setup_logging(logger, log_file="logfile.log", level=logging.INFO): metric_logger = MetricLogger(logger, outputs, initial_metrics, names) metric_logger.log(initial_metrics)