Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve no-key error message in ase target parser #242

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 30 additions & 42 deletions src/metatrain/__main__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
"""The main entry point for the metatrain command line interface."""

import argparse
import importlib
import logging
import os
import sys
import traceback
from datetime import datetime
from pathlib import Path

import metatensor.torch
from omegaconf import OmegaConf

from . import __version__
from .cli.eval import _add_eval_model_parser, eval_model
from .cli.export import _add_export_model_parser, export_model
from .cli.train import _add_train_model_parser, train_model
from .utils.architectures import check_architecture_name
from .cli.eval import _add_eval_model_parser, _prepare_eval_model_args, eval_model
from .cli.export import (
_add_export_model_parser,
_prepare_export_model_args,
export_model,
)
from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
from .utils.logging import setup_logging


Expand Down Expand Up @@ -71,59 +70,48 @@
args = ap.parse_args()
callable = args.__dict__.pop("callable")
debug = args.__dict__.pop("debug")
logfile = None
log_file = None
error_file = Path("error.log")

if debug:
level = logging.DEBUG
else:
level = logging.INFO

if callable == "eval_model":
args.__dict__["model"] = metatensor.torch.atomistic.load_atomistic_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)
elif callable == "export_model":
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")

args.__dict__["model"] = architecture.__model__.load_checkpoint(
args.__dict__.pop("path")
)
elif callable == "train_model":
# define and create `checkpoint_dir` based on current directory and date/time
if callable == "train_model":
# define and create `checkpoint_dir` based on current directory, date and time
checkpoint_dir = _datetime_output_path(now=datetime.now())
os.makedirs(checkpoint_dir)
args.__dict__["checkpoint_dir"] = checkpoint_dir

# save log to file
logfile = checkpoint_dir / "train.log"
args.checkpoint_dir = checkpoint_dir

# merge/override file options with command line options
override_options = args.__dict__.pop("override_options")
if override_options is None:
override_options = {}

args.options = OmegaConf.merge(args.options, override_options)
else:
raise ValueError("internal error when selecting a sub-command.")
log_file = checkpoint_dir / "train.log"
error_file = checkpoint_dir / error_file

with setup_logging(logger, logfile=logfile, level=level):
with setup_logging(logger, log_file=log_file, level=level):
try:
if callable == "eval_model":
_prepare_eval_model_args(args)
eval_model(**args.__dict__)
elif callable == "export_model":
_prepare_export_model_args(args)
export_model(**args.__dict__)
elif callable == "train_model":
_prepare_train_model_args(args)
train_model(**args.__dict__)
else:
raise ValueError("internal error when selecting a sub-command.")
raise ValueError("internal error when selecting a sub-command")

Check warning on line 102 in src/metatrain/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/__main__.py#L102

Added line #L102 was not covered by tests
except Exception as err:
if debug:
traceback.print_exc()
else:
sys.exit(str(err))
logging.error(
"If the error message below is unclear, please help us improve it by "
"opening an issue at https://github.com/lab-cosmo/metatrain/issues. "
"When opening the issue, please include the full traceback log from "
f"{str(error_file.absolute().resolve())!r}. Thank you!\n\n{err}"
)

with open(error_file, "w") as f:
f.write(traceback.format_exc())

sys.exit(1)


if __name__ == "__main__":
Expand Down
11 changes: 10 additions & 1 deletion src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)
parser.add_argument(
"options",
type=OmegaConf.load,
type=str,
help="Eval options file to define a dataset for evaluation.",
)
parser.add_argument(
Expand All @@ -81,6 +81,15 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def _prepare_eval_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for eval_model."""
args.options = OmegaConf.load(args.options)
args.model = metatensor.torch.atomistic.load_atomistic_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)


def _concatenate_tensormaps(
tensormap_dict_list: List[Dict[str, TensorMap]]
) -> Dict[str, TensorMap]:
Expand Down
12 changes: 11 additions & 1 deletion src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse
import importlib
import logging
from pathlib import Path
from typing import Any, Union

import torch

from ..utils.architectures import find_all_architectures
from ..utils.architectures import check_architecture_name, find_all_architectures
from ..utils.export import is_exported
from ..utils.io import check_suffix
from .formatter import CustomHelpFormatter
Expand Down Expand Up @@ -53,6 +54,15 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
"""Export a trained model to allow it to make predictions.

Expand Down
13 changes: 12 additions & 1 deletion src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:

parser.add_argument(
"options",
type=OmegaConf.load,
type=str,
help="Options file",
)
parser.add_argument(
Expand Down Expand Up @@ -85,6 +85,17 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def _prepare_train_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for train_model."""
args.options = OmegaConf.load(args.options)
# merge/override file options with command line options
override_options = args.__dict__.pop("override_options")
if override_options is None:
override_options = {}

args.options = OmegaConf.merge(args.options, override_options)


def train_model(
options: Union[DictConfig, Dict],
output: str = "model.pt",
Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/utils/data/readers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def read_targets(
fileformat=target["forces"]["file_format"],
dtype=dtype,
)
except KeyError:
except Exception:
logger.warning(
f"No Forces found in section {target_key!r}. "
"Continue without forces!"
Expand Down Expand Up @@ -232,7 +232,7 @@ def read_targets(
fileformat=target["stress"]["file_format"],
dtype=dtype,
)
except KeyError:
except Exception:
logger.warning(
f"No Stress found in section {target_key!r}. "
"Continue without stress!"
Expand All @@ -255,7 +255,7 @@ def read_targets(
fileformat=target["virial"]["file_format"],
dtype=dtype,
)
except KeyError:
except Exception:
logger.warning(
f"No Virial found in section {target_key!r}. "
"Continue without virial!"
Expand Down
26 changes: 24 additions & 2 deletions src/metatrain/utils/data/readers/targets/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def read_energy_ase(

blocks = []
for i_system, atoms in enumerate(frames):
if key not in atoms.info:
raise ValueError(
f"energy key {key!r} was not found in system {filename!r} at index "
f"{i_system}"
)

values = torch.tensor([[atoms.info[key]]], dtype=dtype)
samples = Labels(["system"], torch.tensor([[i_system]]))

Expand Down Expand Up @@ -62,6 +68,13 @@ def read_forces_ase(

blocks = []
for i_system, atoms in enumerate(frames):

if key not in atoms.arrays:
raise ValueError(
f"forces key {key!r} was not found in system {filename!r} at index "
f"{i_system}"
)

# We store forces as positions gradients which means we invert the sign
values = -torch.tensor(atoms.arrays[key], dtype=dtype)
values = values.reshape(-1, 3, 1)
Expand Down Expand Up @@ -137,8 +150,7 @@ def _read_virial_stress_ase(
:param is_virial: if target values are stored as stress or virials.
:param dtype: desired data type of returned tensor

:returns:
TensorMap containing the given information
:returns: TensorMap containing the given information
"""
frames = ase.io.read(filename, ":")

Expand All @@ -151,6 +163,16 @@ def _read_virial_stress_ase(

blocks = []
for i_system, atoms in enumerate(frames):
if key not in atoms.info:
if is_virial:
target_name = "virial"
else:
target_name = "stress"

raise ValueError(
f"{target_name} key {key!r} was not found in system {filename!r} at "
f"index {i_system}"
)

values = torch.tensor(atoms.info[key].tolist(), dtype=dtype)

Expand Down
7 changes: 4 additions & 3 deletions src/metatrain/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class ArchitectureError(Exception):

def __init__(self, exception):
super().__init__(
"The error below most likely originates from an architecture. If you think "
"this is a bug, please contact its maintainer (see the architecture's "
f"documentation).\n\n{exception}"
f"{exception}\n\nThe error above most likely originates from an "
"architecture.\n\nIf you think this is a bug, please contact its "
"maintainer (see the architecture's documentation) and include the full "
"traceback error.log."
)
25 changes: 13 additions & 12 deletions src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,17 @@ def _get_digits(value: float) -> Tuple[int, int]:

@contextlib.contextmanager
def setup_logging(
logobj: logging.Logger,
logfile: Optional[Union[str, Path]] = None,
log_obj: logging.Logger,
log_file: Optional[Union[str, Path]] = None,
level: int = logging.WARNING,
):
"""Create a logging environment for a given ``logobj``.

Extracted and adjusted from
github.com/MDAnalysis/mdacli/blob/main/src/mdacli/logger.py

:param logobj: A logging instance
:param logfile: Name of the log file
:param log_obj: A logging instance
:param log_file: Name of the log file
:param level: Set the root logger level to the specified level. If for example set
to :py:obj:`logging.DEBUG` detailed debug logs inludcing filename and function
name are displayed. For :py:obj:`logging.INFO` only the message logged from
Expand All @@ -173,26 +173,27 @@ def setup_logging(
stream_handler.setFormatter(formatter)
handlers.append(stream_handler)

if logfile:
logfile = check_suffix(filename=logfile, suffix=".log")
file_handler = logging.FileHandler(filename=str(logfile), encoding="utf-8")
if log_file:
log_file = check_suffix(filename=log_file, suffix=".log")
file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
file_handler.setFormatter(formatter)
handlers.append(file_handler)

logging.basicConfig(format=format, handlers=handlers, level=level, style="{")

if logfile:
logobj.info(f"This log is also available in {str(logfile)!r}.")
if log_file:
abs_path = str(Path(log_file).absolute().resolve())
log_obj.info(f"This log is also available at {abs_path!r}.")
else:
logobj.info("Logging to file is disabled.")
log_obj.info("Logging to file is disabled.")

for handler in handlers:
logobj.addHandler(handler)
log_obj.addHandler(handler)

yield

finally:
for handler in handlers:
handler.flush()
handler.close()
logobj.removeHandler(handler)
log_obj.removeHandler(handler)
35 changes: 35 additions & 0 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for argument parsing."""

import glob
import shutil
import subprocess
from pathlib import Path
from subprocess import CalledProcessError
from typing import List

import pytest
Expand Down Expand Up @@ -103,3 +105,36 @@ def get_completion_suggestions(partial_word: str) -> List[str]:
def test_subcommand_completion(partial_word, expected_completion):
"""Test that expected subcommand completion matches."""
assert set(get_completion_suggestions(partial_word)) == set(expected_completion)


@pytest.mark.parametrize("subcommand", ["train", "eval"])
def test_error(subcommand, capfd, monkeypatch, tmp_path):
"""Test expected display of errors to stdout and log files."""
monkeypatch.chdir(tmp_path)

command = ["mtt", subcommand]
if subcommand == "eval":
command += ["model.pt"]

command += ["foo.yaml"]

with pytest.raises(CalledProcessError):
subprocess.check_call(command)

stdout_log = capfd.readouterr().out

if subcommand == "train":
error_glob = glob.glob("outputs/*/*/error.log")
error_file = error_glob[0]
else:
error_file = "error.log"

error_file = str(Path(error_file).absolute().resolve())

with open(error_file) as f:
error_log = f.read()

print(error_file)
assert f"please include the full traceback log from {error_file!r}" in stdout_log
assert "No such file or directory" in stdout_log
assert "Traceback" in error_log
Loading
Loading