From 85cffa80b13d7a0e61d73d401d36b3be81fdfc82 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Mon, 30 Sep 2024 16:51:47 +0200 Subject: [PATCH] allow exporting models on remote locations --- src/metatrain/cli/export.py | 21 ++++++++++---------- src/metatrain/utils/io.py | 39 ++++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index d63f11f3e..22a8ae3c5 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -1,14 +1,13 @@ import argparse -import importlib import logging from pathlib import Path from typing import Any, Union import torch -from ..utils.architectures import check_architecture_name, find_all_architectures +from ..utils.architectures import find_all_architectures from ..utils.export import is_exported -from ..utils.io import check_file_extension +from ..utils.io import check_file_extension, load_model from .formatter import CustomHelpFormatter @@ -41,7 +40,10 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: parser.add_argument( "path", type=str, - help="Saved model which should be exported", + help=( + "Saved model which should be exported. Path can be either a URL or a " + "local file." 
+ ), ) parser.add_argument( "-o", @@ -56,15 +58,14 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" - architecture_name = args.__dict__.pop("architecture_name") - check_architecture_name(architecture_name) - architecture = importlib.import_module(f"metatrain.{architecture_name}") - - args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path")) + args.model = load_model( + architecture_name=args.__dict__.pop("architecture_name"), + path=args.__dict__.pop("path"), + ) def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None: - """Export a trained model to allow it to make predictions. + """Export a trained model allowing it to make predictions. This includes predictions within molecular simulation engines. Exported models will be saved with a ``.pt`` file ending. If ``path`` does not end with this file diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index a11cbdd0e..264cce7c4 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -1,6 +1,14 @@ +import importlib import warnings from pathlib import Path -from typing import Union +from typing import Any, Union +from urllib.parse import urlparse +from urllib.request import urlretrieve + +import torch +from metatensor.torch.atomistic import check_atomistic_model + +from .architectures import check_architecture_name def check_file_extension( @@ -30,3 +38,32 @@ return str(path_filename) else: return path_filename + + +def load_model(architecture_name: str, path: Union[str, Path]) -> Any: + """Load a model from a URL or a local file. 
+ + :param architecture_name: name of the architecture whose ``__model__`` class is used to load a checkpoint + :param path: URL or local file path of the saved model + """ + if Path(path).suffix in [".yaml", ".yml"]: + raise ValueError(f"path '{path}' seems to be an options file and not a model") + + if urlparse(str(path)).scheme: + path, _ = urlretrieve(str(path)) + + try: + check_atomistic_model(str(path)) + except ValueError: + check_architecture_name(architecture_name) + architecture = importlib.import_module(f"metatrain.{architecture_name}") + + try: + return architecture.__model__.load_checkpoint(str(path)) + except RuntimeError: + raise ValueError( + f"path '{path}' is not a valid local or remote model file for " + f"the {architecture_name} architecture" + ) + else: + return torch.jit.load(str(path))