allow exporting models on remote locations
PicoCentauri committed Oct 14, 2024
1 parent 161fd56 commit d2ef7fa
Showing 7 changed files with 167 additions and 33 deletions.
22 changes: 16 additions & 6 deletions docs/src/getting-started/checkpoints.rst
@@ -25,20 +25,30 @@ or
mtt train options.yaml -c model.ckpt
Checkpoints can also be turned into exported models using the ``export`` sub-command.
The command requires the *architecture name* and the saved checkpoint *path* as
positional arguments

.. code-block:: bash
mtt export model.ckpt -o model.pt
mtt export experimental.soap_bpnn model.ckpt -o model.pt
or

.. code-block:: bash
mtt export model.ckpt --output model.pt
mtt export experimental.soap_bpnn model.ckpt --output model.pt
To ease the distribution of models, the ``export`` command also supports fetching
models from remote locations. To export a remote model, provide a URL instead of a
file path.

.. code-block:: bash
mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt
Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have
several dependencies and may become unusable if the corresponding architecture is
updated. In contrast, exported models (``.pt``) act as standalone files.
For long-term usage, you should export your model! Exporting a model is also necessary
if you want to use it in other frameworks, especially in molecular simulations
(see the :ref:`tutorials`).
updated. In contrast, exported models (``.pt``) act as standalone files. For long-term
usage, you should export your model! Exporting a model is also necessary if you want to
use it in other frameworks, especially in molecular simulations (see the
:ref:`tutorials`).
44 changes: 22 additions & 22 deletions src/metatrain/cli/export.py
@@ -3,11 +3,9 @@
from pathlib import Path
from typing import Any, Union

import torch

from ..utils.architectures import find_all_architectures, import_architecture
from ..utils.architectures import find_all_architectures
from ..utils.export import is_exported
from ..utils.io import check_file_extension
from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter


@@ -40,7 +38,10 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
parser.add_argument(
"path",
type=str,
help="Saved model which should be exported",
help=(
"Saved model which should be exported. Path can be either a URL or a "
"local file."
),
)
parser.add_argument(
"-o",
@@ -55,14 +56,14 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:

def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
architecture = import_architecture(architecture_name)

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))
args.model = load_model(
architecture_name=args.__dict__.pop("architecture_name"),
path=args.__dict__.pop("path"),
)


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
"""Export a trained model to allow it to make predictions.
"""Export a trained model allowing it to make predictions.
This includes predictions within molecular simulation engines. Exported models will
be saved with a ``.pt`` file ending. If ``path`` does not end with this file
@@ -71,16 +72,15 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
:param model: model to be exported
:param output: path to save the exported model
"""
path = str(check_file_extension(filename=output, extension=".pt"))
path = str(
Path(check_file_extension(filename=output, extension=".pt"))
.absolute()
.resolve()
)
extensions_path = str(Path("extensions/").absolute().resolve())

if is_exported(model):
logger.info(f"The model is already exported. Saving it to `{path}`.")
torch.jit.save(model, path)
else:
extensions_path = "extensions/"
logger.info(
f"Exporting model to '{path}' and extensions to '{extensions_path}'"
)
mts_atomistic_model = model.export()
mts_atomistic_model.save(path, collect_extensions=extensions_path)
logger.info("Model exported successfully")
if not is_exported(model):
model = model.export()

model.save(path, collect_extensions=extensions_path)
logger.info(f"Model exported to '{path}' and extensions to '{extensions_path}'")
4 changes: 4 additions & 0 deletions src/metatrain/utils/export.py
@@ -54,6 +54,10 @@ def is_exported(model: Any) -> bool:
:param model: The model to check
:return: :py:obj:`True` if the ``model`` has been exported, :py:obj:`False`
otherwise.
.. seealso::
:py:func:`utils.io.is_exported_file <metatrain.utils.io.is_exported_file>`
to verify if a saved model on disk is already exported.
"""
# If the model is saved and loaded again, its type is RecursiveScriptModule
if type(model) in [
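A short illustrative sketch of how the two checks complement each other, using the companion ``is_exported_file`` and ``load_model`` helpers added in ``utils/io.py`` below (module paths assumed from the file layout; the checkpoint name is a placeholder).

from metatrain.utils.export import is_exported               # assumed module path
from metatrain.utils.io import is_exported_file, load_model  # assumed module path

ckpt = "model.ckpt"  # placeholder file name
print(is_exported_file(ckpt))  # inspects the saved file on disk without loading it
model = load_model("experimental.soap_bpnn", ckpt)
print(is_exported(model))      # inspects an already loaded, in-memory model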
57 changes: 56 additions & 1 deletion src/metatrain/utils/io.py
@@ -1,6 +1,13 @@
import warnings
from pathlib import Path
from typing import Union
from typing import Any, Union
from urllib.parse import urlparse
from urllib.request import urlretrieve

import torch
from metatensor.torch.atomistic import check_atomistic_model

from .architectures import import_architecture


def check_file_extension(
@@ -30,3 +37,51 @@ def check_file_extension(
return str(path_filename)
else:
return path_filename


def is_exported_file(path: str) -> bool:
"""Check if a saved model file has been exported to a MetatensorAtomisticModel.
:param path: model path
:return: :py:obj:`True` if the model saved at ``path`` has been exported, :py:obj:`False`
otherwise.
.. seealso::
:py:func:`utils.export.is_exported <metatrain.utils.export.is_exported>` to
verify if an already loaded model is exported.
"""
try:
check_atomistic_model(str(path))
return True
except ValueError:
return False


def load_model(architecture_name: str, path: Union[str, Path]) -> Any:
"""Loads a module from an URL or a local file.
:param name: name of the architecture
:param path: local or remote path to a model. For supported URL schemes see
:py:class`urllib.request`
:raises ValueError: if ``path`` is a YAML option file and no model
:raises ValueError: if the checkpoint saved in ``path`` does not math the given
``architecture_name``
"""
if Path(path).suffix in [".yaml", ".yml"]:
raise ValueError(f"path '{path}' seems to be a YAML option file and no model")

if urlparse(str(path)).scheme:
path, _ = urlretrieve(str(path))

if is_exported_file(str(path)):
return torch.jit.load(str(path))
else: # model is a checkpoint
architecture = import_architecture(architecture_name)

try:
return architecture.__model__.load_checkpoint(str(path))
except Exception as err:
raise ValueError(
f"path '{path}' is not a valid model file for the {architecture_name} "
"architecture"
) from err
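As a hedged sketch of the inputs ``load_model`` accepts (the paths and URL below are placeholders; the module path is assumed from the file layout above):

from metatrain.utils.io import load_model  # assumed module path

# A local checkpoint is dispatched to the architecture's load_checkpoint().
model = load_model("experimental.soap_bpnn", "model.ckpt")

# A remote checkpoint (any scheme urllib.request supports) is downloaded first.
model = load_model("experimental.soap_bpnn", "https://my.url.com/model.ckpt")

# An already exported file is returned via torch.jit.load(); the architecture
# name is still required as an argument but not used in this branch.
model = load_model("experimental.soap_bpnn", "model.pt")

# A YAML options file (e.g. "options.yaml") raises a ValueError before any loading.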
7 changes: 6 additions & 1 deletion tests/cli/test_export_model.py
@@ -4,6 +4,7 @@
"""

import glob
import logging
import subprocess
from pathlib import Path

@@ -20,9 +21,10 @@


@pytest.mark.parametrize("path", [Path("exported.pt"), "exported.pt"])
def test_export(monkeypatch, tmp_path, path):
def test_export(monkeypatch, tmp_path, path, caplog):
"""Tests the export_model function."""
monkeypatch.chdir(tmp_path)
caplog.set_level(logging.INFO)

dataset_info = DatasetInfo(
length_unit="angstrom",
@@ -42,6 +44,9 @@ def test_export(monkeypatch, tmp_path, path):

assert Path(path).is_file()

# Test log message
assert "Model exported to" in caplog.text


@pytest.mark.parametrize("output", [None, "exported.pt"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
60 changes: 59 additions & 1 deletion tests/utils/test_io.py
@@ -1,8 +1,16 @@
from pathlib import Path

import pytest
from torch.jit._script import RecursiveScriptModule

from metatrain.utils.io import check_file_extension
from metatrain.experimental.soap_bpnn.model import SoapBpnn
from metatrain.utils.io import check_file_extension, is_exported_file, load_model

from . import RESOURCES_PATH


def is_None(*args, **kwargs) -> None:
return None


@pytest.mark.parametrize("filename", ["example.txt", Path("example.txt")])
@@ -21,3 +29,53 @@ def test_warning_on_missing_suffix(filename):

assert str(result) == "example.txt"
assert isinstance(result, type(filename))


def test_is_exported_file():
assert is_exported_file(RESOURCES_PATH / "model-32-bit.pt")
assert not is_exported_file(RESOURCES_PATH / "model-32-bit.ckpt")


@pytest.mark.parametrize(
"path",
[
RESOURCES_PATH / "model-32-bit.ckpt",
str(RESOURCES_PATH / "model-32-bit.ckpt"),
f"file:{str(RESOURCES_PATH / 'model-32-bit.ckpt')}",
],
)
def test_load_model_checkpoint(path):
model = load_model("experimental.soap_bpnn", path)
assert type(model) is SoapBpnn


@pytest.mark.parametrize(
"path",
[
RESOURCES_PATH / "model-32-bit.pt",
str(RESOURCES_PATH / "model-32-bit.pt"),
f"file:{str(RESOURCES_PATH / 'model-32-bit.pt')}",
],
)
def test_load_model_exported(path):
model = load_model("experimental.soap_bpnn", path)
assert type(model) is RecursiveScriptModule


@pytest.mark.parametrize("suffix", [".yml", ".yaml"])
def test_load_model_yaml(suffix):
match = f"path 'foo{suffix}' seems to be a YAML option file and no model"
with pytest.raises(ValueError, match=match):
load_model("experimental.soap_bpnn", f"foo{suffix}")


def test_load_model_unknown_model():
architecture_name = "experimental.pet"
path = RESOURCES_PATH / "model-32-bit.ckpt"

match = (
f"path '{path}' is not a valid model file for the {architecture_name} "
"architecture"
)
with pytest.raises(ValueError, match=match):
load_model(architecture_name, path)
6 changes: 4 additions & 2 deletions tox.ini
@@ -59,7 +59,9 @@ deps =
pytest-cov
pytest-xdist
changedir = tests
extras = soap-bpnn # this model is used in the package tests
extras = # architectures used in the package tests
soap-bpnn
pet
allowlist_externals = bash
commands_pre = bash {toxinidir}/tests/resources/generate-outputs.sh
commands =
@@ -92,7 +94,7 @@ description = Run SOAP-BPNN tests with pytest
passenv = *
deps =
pytest
extras = soap-bpnn
extras =soap-bpnn
changedir = src/metatrain/experimental/soap_bpnn/tests/
commands =
pytest {posargs}