Skip to content

Commit

Permalink
allow exporting models on remote locations
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Oct 3, 2024
1 parent 96f2c74 commit 85cffa8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
21 changes: 11 additions & 10 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import argparse
import importlib
import logging
from pathlib import Path
from typing import Any, Union

import torch

from ..utils.architectures import check_architecture_name, find_all_architectures
from ..utils.architectures import find_all_architectures
from ..utils.export import is_exported
from ..utils.io import check_file_extension
from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter


Expand Down Expand Up @@ -41,7 +40,10 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
parser.add_argument(
"path",
type=str,
help="Saved model which should be exported",
help=(
"Saved model which should be exported. Path can be either a URL or a "
"local file."
),
)
parser.add_argument(
"-o",
Expand All @@ -56,15 +58,14 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:

def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))
args.model = load_model(
architecture_name=args.__dict__.pop("architecture_name"),
path=args.__dict__.pop("path"),
)


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
"""Export a trained model to allow it to make predictions.
"""Export a trained model allowing it to make predictions.
This includes predictions within molecular simulation engines. Exported models will
be saved with a ``.pt`` file ending. If ``path`` does not end with this file
Expand Down
39 changes: 38 additions & 1 deletion src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import importlib
import warnings
from pathlib import Path
from typing import Union
from typing import Any, Union
from urllib.parse import urlparse
from urllib.request import urlretrieve

import torch
from metatensor.torch.atomistic import check_atomistic_model

from .architectures import check_architecture_name


def check_file_extension(
Expand Down Expand Up @@ -30,3 +38,32 @@ def check_file_extension(
return str(path_filename)
else:
return path_filename


def load_model(architecture_name: str, path: Union[str, Path]) -> Any:
"""Loads a module from an URL or a local file.
:param architecture_name: TODO
:param path: TODO
"""
if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(f"path '{path}' seems to be option file and not a model")

Check warning on line 50 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L50

Added line #L50 was not covered by tests

if urlparse(str(path)).scheme:
path, _ = urlretrieve(str(path))

Check warning on line 53 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L53

Added line #L53 was not covered by tests

try:
check_atomistic_model(str(path))
except ValueError:
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")

try:
return architecture.__model__.load_checkpoint(str(path))
except RuntimeError:
raise ValueError(

Check warning on line 64 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L63-L64

Added lines #L63 - L64 were not covered by tests
f"path '{path}' is not a valid a local or a remote model file for "
"the {architecture_name} architecture"
)
else:
return torch.jit.load(str(path))

Check warning on line 69 in src/metatrain/utils/io.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/io.py#L69

Added line #L69 was not covered by tests

0 comments on commit 85cffa8

Please sign in to comment.