Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error when architecture deps are missing #350

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import argparse
import importlib
import logging
from pathlib import Path
from typing import Any, Union

import torch

from ..utils.architectures import check_architecture_name, find_all_architectures
from ..utils.architectures import find_all_architectures, import_architecture
from ..utils.export import is_exported
from ..utils.io import check_file_extension
from .formatter import CustomHelpFormatter
Expand Down Expand Up @@ -57,8 +56,7 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")
architecture = import_architecture(architecture_name)

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))

Expand Down
9 changes: 6 additions & 3 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import importlib
import itertools
import json
import logging
Expand All @@ -14,7 +13,11 @@
from omegaconf import DictConfig, OmegaConf

from .. import PACKAGE_ROOT
from ..utils.architectures import check_architecture_options, get_default_hypers
from ..utils.architectures import (
check_architecture_options,
get_default_hypers,
import_architecture,
)
from ..utils.data import (
DatasetInfo,
TargetInfoDict,
Expand Down Expand Up @@ -135,7 +138,7 @@ def train_model(
check_architecture_options(
name=architecture_name, options=OmegaConf.to_container(options["architecture"])
)
architecture = importlib.import_module(f"metatrain.{architecture_name}")
architecture = import_architecture(architecture_name)

logger.info(f"Running training for {architecture_name!r} architecture")

Expand Down
25 changes: 25 additions & 0 deletions src/metatrain/utils/architectures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import difflib
import importlib
import json
import logging
from importlib.util import find_spec
Expand Down Expand Up @@ -110,6 +111,30 @@ def get_architecture_name(path: Union[str, Path]) -> str:
return name


def import_architecture(name: str):
"""Import an architecture.

:param name: name of the architecture
:raises ImportError: if the architecture dependencies are not met
"""
check_architecture_name(name)
try:
return importlib.import_module(f"metatrain.{name}")
except ImportError as err:
# consistent name with pyproject.toml's `optional-dependencies` section
name_for_deps = name
if "experimental." in name or "deprecated." in name:
name_for_deps = ".".join(name.split(".")[1:])

name_for_deps = name_for_deps.replace("_", "-")

raise ImportError(
f"Trying to import '{name}' but architecture dependencies "
f"seem not be installed. \n"
f"Try to install them with `pip install .[{name_for_deps}]`"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Try to install them with `pip install .[{name_for_deps}]`"
f"Try to install them with `pip install metatrain[{name_for_deps}]`"

) from err


def get_architecture_path(name: str) -> Path:
"""Return the relative path to the architeture directory.

Expand Down
5 changes: 2 additions & 3 deletions src/metatrain/utils/omegaconf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import json
from typing import Any, Union

Expand All @@ -7,13 +6,13 @@
from omegaconf.basecontainer import BaseContainer

from .. import PACKAGE_ROOT, RANDOM_SEED
from .architectures import import_architecture
from .devices import pick_devices
from .jsonschema import validate


def _get_architecture_model(conf: BaseContainer) -> Any:
architecture_name = conf["architecture"]["name"]
architecture = importlib.import_module(f"metatrain.{architecture_name}")
architecture = import_architecture(conf["architecture"]["name"])
return architecture.__model__


Expand Down
30 changes: 30 additions & 0 deletions tests/utils/test_architectures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
from pathlib import Path

import pytest
Expand All @@ -11,9 +12,14 @@
get_architecture_name,
get_architecture_path,
get_default_hypers,
import_architecture,
)


def is_None(*args, **kwargs) -> None:
return None


def test_find_all_architectures():
all_arches = find_all_architectures()
assert len(all_arches) == 4
Expand Down Expand Up @@ -116,3 +122,27 @@ def test_check_architecture_options_error_raise():
match = r"Unrecognized options \('num_epochxxx' was unexpected\)"
with pytest.raises(ValidationError, match=match):
check_architecture_options(name=name, options=options)


def test_import_architecture():
name = "experimental.soap_bpnn"
architecture_ref = importlib.import_module(f"metatrain.{name}")
assert import_architecture(name) == architecture_ref


def test_import_architecture_erro(monkeypatch):
# `check_architecture_name` is called inside `import_architecture` and we have to
# disble the check to allow passing our "unknown" fancy-model below.
monkeypatch.setattr(
"metatrain.utils.architectures.check_architecture_name", is_None
)

name = "experimental.fancy_model"
name_for_deps = "fancy-model"

match = (
rf"Trying to import '{name}' but architecture dependencies seem not be "
rf"installed. \nTry to install them with `pip install .\[{name_for_deps}\]`"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rf"installed. \nTry to install them with `pip install .\[{name_for_deps}\]`"
rf"installed. \nTry to install them with `pip install metatrain\[{name_for_deps}\]`"

)
with pytest.raises(ImportError, match=match):
import_architecture(name)