From 161fd564b6d668da0b838736c65b79df5ba69bb9 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 4 Oct 2024 18:10:15 +0200 Subject: [PATCH] improve error when architecture deps are missing (#350) --- src/metatrain/cli/export.py | 6 ++---- src/metatrain/cli/train.py | 9 ++++++--- src/metatrain/utils/architectures.py | 25 +++++++++++++++++++++++ src/metatrain/utils/omegaconf.py | 5 ++--- tests/utils/test_architectures.py | 30 ++++++++++++++++++++++++++++ 5 files changed, 65 insertions(+), 10 deletions(-) diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index d63f11f3e..378d93145 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -1,12 +1,11 @@ import argparse -import importlib import logging from pathlib import Path from typing import Any, Union import torch -from ..utils.architectures import check_architecture_name, find_all_architectures +from ..utils.architectures import find_all_architectures, import_architecture from ..utils.export import is_exported from ..utils.io import check_file_extension from .formatter import CustomHelpFormatter @@ -57,8 +56,7 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" architecture_name = args.__dict__.pop("architecture_name") - check_architecture_name(architecture_name) - architecture = importlib.import_module(f"metatrain.{architecture_name}") + architecture = import_architecture(architecture_name) args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path")) diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index ffc880f73..5523b6ef4 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -1,5 +1,4 @@ import argparse -import importlib import itertools import json import logging @@ -14,7 +13,11 @@ from omegaconf import DictConfig, OmegaConf from .. import PACKAGE_ROOT -from ..utils.architectures import check_architecture_options, get_default_hypers +from ..utils.architectures import ( + check_architecture_options, + get_default_hypers, + import_architecture, +) from ..utils.data import ( DatasetInfo, TargetInfoDict, @@ -135,7 +138,7 @@ def train_model( check_architecture_options( name=architecture_name, options=OmegaConf.to_container(options["architecture"]) ) - architecture = importlib.import_module(f"metatrain.{architecture_name}") + architecture = import_architecture(architecture_name) logger.info(f"Running training for {architecture_name!r} architecture") diff --git a/src/metatrain/utils/architectures.py b/src/metatrain/utils/architectures.py index 6d3d66966..33420f99a 100644 --- a/src/metatrain/utils/architectures.py +++ b/src/metatrain/utils/architectures.py @@ -1,4 +1,5 @@ import difflib +import importlib import json import logging from importlib.util import find_spec @@ -110,6 +111,30 @@ def get_architecture_name(path: Union[str, Path]) -> str: return name +def import_architecture(name: str): + """Import an architecture. + + :param name: name of the architecture + :raises ImportError: if the architecture dependencies are not met + """ + check_architecture_name(name) + try: + return importlib.import_module(f"metatrain.{name}") + except ImportError as err: + # consistent name with pyproject.toml's `optional-dependencies` section + name_for_deps = name + if "experimental." in name or "deprecated." in name: + name_for_deps = ".".join(name.split(".")[1:]) + + name_for_deps = name_for_deps.replace("_", "-") + + raise ImportError( + f"Trying to import '{name}' but architecture dependencies " + f"seem not be installed. \n" + f"Try to install them with `pip install .[{name_for_deps}]`" + ) from err + + def get_architecture_path(name: str) -> Path: """Return the relative path to the architeture directory. diff --git a/src/metatrain/utils/omegaconf.py b/src/metatrain/utils/omegaconf.py index e156bc132..3a223df5d 100644 --- a/src/metatrain/utils/omegaconf.py +++ b/src/metatrain/utils/omegaconf.py @@ -1,4 +1,3 @@ -import importlib import json from typing import Any, Union @@ -7,13 +6,13 @@ from omegaconf.basecontainer import BaseContainer from .. import PACKAGE_ROOT, RANDOM_SEED +from .architectures import import_architecture from .devices import pick_devices from .jsonschema import validate def _get_architecture_model(conf: BaseContainer) -> Any: - architecture_name = conf["architecture"]["name"] - architecture = importlib.import_module(f"metatrain.{architecture_name}") + architecture = import_architecture(conf["architecture"]["name"]) return architecture.__model__ diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index a83db97d4..3bb788392 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path import pytest @@ -11,9 +12,14 @@ get_architecture_name, get_architecture_path, get_default_hypers, + import_architecture, ) +def is_None(*args, **kwargs) -> None: + return None + + def test_find_all_architectures(): all_arches = find_all_architectures() assert len(all_arches) == 4 @@ -116,3 +122,27 @@ def test_check_architecture_options_error_raise(): match = r"Unrecognized options \('num_epochxxx' was unexpected\)" with pytest.raises(ValidationError, match=match): check_architecture_options(name=name, options=options) + + +def test_import_architecture(): + name = "experimental.soap_bpnn" + architecture_ref = importlib.import_module(f"metatrain.{name}") + assert import_architecture(name) == architecture_ref + + +def test_import_architecture_erro(monkeypatch): + # `check_architecture_name` is called inside `import_architecture` and we have to + # disble the check to allow passing our "unknown" fancy-model below. + monkeypatch.setattr( + "metatrain.utils.architectures.check_architecture_name", is_None + ) + + name = "experimental.fancy_model" + name_for_deps = "fancy-model" + + match = ( + rf"Trying to import '{name}' but architecture dependencies seem not be " + rf"installed. \nTry to install them with `pip install .\[{name_for_deps}\]`" + ) + with pytest.raises(ImportError, match=match): + import_architecture(name)