Skip to content

Commit

Permalink
Write a traceback file
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jun 12, 2024
1 parent 348a549 commit 8c46dc5
Show file tree
Hide file tree
Showing 13 changed files with 135 additions and 82 deletions.
66 changes: 26 additions & 40 deletions src/metatrain/__main__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
"""The main entry point for the metatrain command line interface."""

import argparse
import importlib
import logging
import os
import sys
import traceback
from datetime import datetime
from pathlib import Path

import metatensor.torch
from omegaconf import OmegaConf

from . import __version__
from .cli.eval import _add_eval_model_parser, eval_model
from .cli.export import _add_export_model_parser, export_model
from .cli.train import _add_train_model_parser, train_model
from .utils.architectures import check_architecture_name
from .cli.eval import _add_eval_model_parser, _prepare_eval_model_args, eval_model
from .cli.export import (
_add_export_model_parser,
_prepare_export_model_args,
export_model,
)
from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
from .utils.logging import setup_logging


Expand Down Expand Up @@ -71,62 +70,49 @@ def main():
args = ap.parse_args()
callable = args.__dict__.pop("callable")
debug = args.__dict__.pop("debug")
logfile = None
log_file = None
error_file = Path("error.log")

if debug:
level = logging.DEBUG
else:
level = logging.INFO

if callable == "eval_model":
args.__dict__["model"] = metatensor.torch.atomistic.load_atomistic_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)
elif callable == "export_model":
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")

args.__dict__["model"] = architecture.__model__.load_checkpoint(
args.__dict__.pop("path")
)
elif callable == "train_model":
# define and create `checkpoint_dir` based on current directory and date/time
if callable == "train_model":
# define and create `checkpoint_dir` based on current directory, date and time
checkpoint_dir = _datetime_output_path(now=datetime.now())
os.makedirs(checkpoint_dir)
args.__dict__["checkpoint_dir"] = checkpoint_dir

# save log to file
logfile = checkpoint_dir / "train.log"

# merge/override file options with command line options
override_options = args.__dict__.pop("override_options")
if override_options is None:
override_options = {}
args.checkpoint_dir = checkpoint_dir

args.options = OmegaConf.merge(args.options, override_options)
else:
raise ValueError("internal error when selecting a sub-command.")
log_file = checkpoint_dir / "train.log"
error_file = checkpoint_dir / error_file

with setup_logging(logger, logfile=logfile, level=level):
with setup_logging(logger, log_file=log_file, level=level):
try:
if callable == "eval_model":
_prepare_eval_model_args(args)
eval_model(**args.__dict__)
elif callable == "export_model":
_prepare_export_model_args(args)
export_model(**args.__dict__)
elif callable == "train_model":
_prepare_train_model_args(args)
train_model(**args.__dict__)
else:
raise ValueError("internal error when selecting a sub-command.")
raise ValueError("internal error when selecting a sub-command")
except Exception as err:
logging.error({traceback.format_exc()})
logging.error(
"If the error message below is unclear, please help us improve it by "
"opening an issue at https://github.com/lab-cosmo/metatrain/issues. "
f"Thank you!\n\n{type(err).__name__}: {err}"
"When opening the issue, please include the full traceback log from "
f"{str(error_file.absolute().resolve())!r}. Thank you!\n\n{err}"
)

with open(error_file, "w") as f:
f.write(traceback.format_exc())

sys.exit(1)


if __name__ == "__main__":
main()
11 changes: 10 additions & 1 deletion src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)
parser.add_argument(
"options",
type=OmegaConf.load,
type=str,
help="Eval options file to define a dataset for evaluation.",
)
parser.add_argument(
Expand All @@ -81,6 +81,15 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def _prepare_eval_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for eval_model."""
args.options = OmegaConf.load(args.options)
args.model = metatensor.torch.atomistic.load_atomistic_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)


def _concatenate_tensormaps(
tensormap_dict_list: List[Dict[str, TensorMap]]
) -> Dict[str, TensorMap]:
Expand Down
12 changes: 11 additions & 1 deletion src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse
import importlib
import logging
from pathlib import Path
from typing import Any, Union

import torch

from ..utils.architectures import find_all_architectures
from ..utils.architectures import check_architecture_name, find_all_architectures
from ..utils.export import is_exported
from ..utils.io import check_suffix
from .formatter import CustomHelpFormatter
Expand Down Expand Up @@ -53,6 +54,15 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
"""Export a trained model to allow it to make predictions.
Expand Down
13 changes: 12 additions & 1 deletion src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:

parser.add_argument(
"options",
type=OmegaConf.load,
type=str,
help="Options file",
)
parser.add_argument(
Expand Down Expand Up @@ -85,6 +85,17 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def _prepare_train_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for train_model."""
args.options = OmegaConf.load(args.options)
# merge/override file options with command line options
override_options = args.__dict__.pop("override_options")
if override_options is None:
override_options = {}

args.options = OmegaConf.merge(args.options, override_options)


def train_model(
options: Union[DictConfig, Dict],
output: str = "model.pt",
Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/utils/data/readers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def read_targets(
fileformat=target["forces"]["file_format"],
dtype=dtype,
)
except KeyError:
except Exception:
logger.warning(
f"No Forces found in section {target_key!r}. "
"Continue without forces!"
Expand Down Expand Up @@ -232,7 +232,7 @@ def read_targets(
fileformat=target["stress"]["file_format"],
dtype=dtype,
)
except KeyError:
except Exception:
logger.warning(
f"No Stress found in section {target_key!r}. "
"Continue without stress!"
Expand All @@ -255,7 +255,7 @@ def read_targets(
fileformat=target["virial"]["file_format"],
dtype=dtype,
)
except KeyError:
except Exception:
logger.warning(
f"No Virial found in section {target_key!r}. "
"Continue without virial!"
Expand Down
9 changes: 4 additions & 5 deletions src/metatrain/utils/data/readers/targets/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def read_energy_ase(
blocks = []
for i_system, atoms in enumerate(frames):
if key not in atoms.info:
raise KeyError(
raise ValueError(
f"energy key {key!r} was not found in system {filename!r} at index "
f"{i_system}"
)
Expand Down Expand Up @@ -70,7 +70,7 @@ def read_forces_ase(
for i_system, atoms in enumerate(frames):

if key not in atoms.arrays:
raise KeyError(
raise ValueError(
f"forces key {key!r} was not found in system {filename!r} at index "
f"{i_system}"
)
Expand Down Expand Up @@ -150,8 +150,7 @@ def _read_virial_stress_ase(
:param is_virial: if target values are stored as stress or virials.
:param dtype: desired data type of returned tensor
:returns:
TensorMap containing the given information
:returns: TensorMap containing the given information
"""
frames = ase.io.read(filename, ":")

Expand All @@ -170,7 +169,7 @@ def _read_virial_stress_ase(
else:
target_name = "stress"

raise KeyError(
raise ValueError(
f"{target_name} key {key!r} was not found in system {filename!r} at "
f"index {i_system}"
)
Expand Down
7 changes: 4 additions & 3 deletions src/metatrain/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class ArchitectureError(Exception):

def __init__(self, exception):
super().__init__(
f"{exception}\n\nThe error wbove most likely originates from an "
"architecture. If you think this is a bug, please contact its maintainer "
"(see the architecture's documentation)."
f"{exception}\n\nThe error above most likely originates from an "
"architecture.\n\nIf you think this is a bug, please contact its "
"maintainer (see the architecture's documentation) and include the full "
"traceback error.log."
)
25 changes: 13 additions & 12 deletions src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,17 @@ def _get_digits(value: float) -> Tuple[int, int]:

@contextlib.contextmanager
def setup_logging(
logobj: logging.Logger,
logfile: Optional[Union[str, Path]] = None,
log_obj: logging.Logger,
log_file: Optional[Union[str, Path]] = None,
level: int = logging.WARNING,
):
"""Create a logging environment for a given ``logobj``.
Extracted and adjusted from
github.com/MDAnalysis/mdacli/blob/main/src/mdacli/logger.py
:param logobj: A logging instance
:param logfile: Name of the log file
:param log_obj: A logging instance
:param log_file: Name of the log file
:param level: Set the root logger level to the specified level. If for example set
to :py:obj:`logging.DEBUG` detailed debug logs inludcing filename and function
name are displayed. For :py:obj:`logging.INFO` only the message logged from
Expand All @@ -166,26 +166,27 @@ def setup_logging(
stream_handler.setFormatter(formatter)
handlers.append(stream_handler)

if logfile:
logfile = check_suffix(filename=logfile, suffix=".log")
file_handler = logging.FileHandler(filename=str(logfile), encoding="utf-8")
if log_file:
log_file = check_suffix(filename=log_file, suffix=".log")
file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
file_handler.setFormatter(formatter)
handlers.append(file_handler)

logging.basicConfig(format=format, handlers=handlers, level=level, style="{")

if logfile:
logobj.info(f"This log is also available in {str(logfile)!r}.")
if log_file:
abs_path = str(Path(log_file).absolute().resolve())
log_obj.info(f"This log is also available at {abs_path!r}.")
else:
logobj.info("Logging to file is disabled.")
log_obj.info("Logging to file is disabled.")

for handler in handlers:
logobj.addHandler(handler)
log_obj.addHandler(handler)

yield

finally:
for handler in handlers:
handler.flush()
handler.close()
logobj.removeHandler(handler)
log_obj.removeHandler(handler)
35 changes: 35 additions & 0 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for argument parsing."""

import glob
import shutil
import subprocess
from pathlib import Path
from subprocess import CalledProcessError
from typing import List

import pytest
Expand Down Expand Up @@ -103,3 +105,36 @@ def get_completion_suggestions(partial_word: str) -> List[str]:
def test_subcommand_completion(partial_word, expected_completion):
"""Test that expected subcommand completion matches."""
assert set(get_completion_suggestions(partial_word)) == set(expected_completion)


@pytest.mark.parametrize("subcommand", ["train", "eval"])
def test_error(subcommand, capfd, monkeypatch, tmp_path):
"""Test expected display of errors to stdout and log files."""
monkeypatch.chdir(tmp_path)

command = ["mtt", subcommand]
if subcommand == "eval":
command += ["model.pt"]

command += ["foo.yaml"]

with pytest.raises(CalledProcessError):
subprocess.check_call(command)

stdout_log = capfd.readouterr().out

if subcommand == "train":
error_glob = glob.glob("outputs/*/*/error.log")
error_file = error_glob[0]
else:
error_file = "error.log"

error_file = str(Path(error_file).absolute().resolve())

with open(error_file) as f:
error_log = f.read()

print(error_file)
assert f"please include the full traceback log from {error_file!r}" in stdout_log
assert "No such file or directory" in stdout_log
assert "Traceback" in error_log
Loading

0 comments on commit 8c46dc5

Please sign in to comment.