From c94be49612daebfca0fc44090a870a3e2430c5b4 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Mon, 30 Sep 2024 16:51:47 +0200 Subject: [PATCH] allow exporting models on remote locations --- docs/src/getting-started/checkpoints.rst | 22 ++++++--- src/metatrain/cli/export.py | 19 ++++---- src/metatrain/utils/export.py | 4 ++ src/metatrain/utils/io.py | 57 +++++++++++++++++++++- tests/utils/test_io.py | 60 +++++++++++++++++++++++- tox.ini | 6 ++- 6 files changed, 150 insertions(+), 18 deletions(-) diff --git a/docs/src/getting-started/checkpoints.rst b/docs/src/getting-started/checkpoints.rst index 7e4a30e93..6c95afe8f 100644 --- a/docs/src/getting-started/checkpoints.rst +++ b/docs/src/getting-started/checkpoints.rst @@ -25,20 +25,30 @@ or mtt train options.yaml -c model.ckpt Checkpoints can also be turned into exported models using the ``export`` sub-command. +The command requires the `architecture name` and the saved checkpoint `path` as +positional arguments. .. code-block:: bash - mtt export model.ckpt -o model.pt + mtt export experimental.soap_bpnn model.ckpt -o model.pt or .. code-block:: bash - mtt export model.ckpt --output model.pt + mtt export experimental.soap_bpnn model.ckpt --output model.pt +To ease the distribution of models, the ``export`` command also supports fetching +models from remote locations. To export a remote model you can provide a URL instead of +a file path. + +.. code-block:: bash + + mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have several dependencies and may become unusable if the corresponding architecture is -updated. 
In contrast, exported models (``.pt``) act as standalone files. For long-term +usage, you should export your model! Exporting a model is also necessary if you want to +use it in other frameworks, especially in molecular simulations (see the +:ref:`tutorials`). diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index 378d93145..22a8ae3c5 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -5,9 +5,9 @@ import torch -from ..utils.architectures import find_all_architectures, import_architecture +from ..utils.architectures import find_all_architectures from ..utils.export import is_exported -from ..utils.io import check_file_extension +from ..utils.io import check_file_extension, load_model from .formatter import CustomHelpFormatter @@ -40,7 +40,10 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: parser.add_argument( "path", type=str, - help="Saved model which should be exported", + help=( + "Saved model which should be exported. Path can be either a URL or a " + "local file." + ), ) parser.add_argument( "-o", @@ -55,14 +58,14 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None: def _prepare_export_model_args(args: argparse.Namespace) -> None: """Prepare arguments for export_model.""" - architecture_name = args.__dict__.pop("architecture_name") - architecture = import_architecture(architecture_name) - - args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path")) + args.model = load_model( + architecture_name=args.__dict__.pop("architecture_name"), + path=args.__dict__.pop("path"), + ) def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None: - """Export a trained model to allow it to make predictions. + """Export a trained model allowing it to make predictions. This includes predictions within molecular simulation engines. Exported models will be saved with a ``.pt`` file ending. 
If ``path`` does not end with this file diff --git a/src/metatrain/utils/export.py b/src/metatrain/utils/export.py index bc28c2719..717b47944 100644 --- a/src/metatrain/utils/export.py +++ b/src/metatrain/utils/export.py @@ -54,6 +54,10 @@ def is_exported(model: Any) -> bool: :param model: The model to check :return: :py:obj:`True` if the ``model`` has been exported, :py:obj:`False` otherwise. + + .. seealso:: + :py:func:`utils.io.is_exported_file <metatrain.utils.io.is_exported_file>` + to verify if a saved model on disk is already exported. """ # If the model is saved and loaded again, its type is RecursiveScriptModule if type(model) in [ diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index a11cbdd0e..cab8e6121 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -1,6 +1,13 @@ import warnings from pathlib import Path -from typing import Union +from typing import Any, Union +from urllib.parse import urlparse +from urllib.request import urlretrieve + +import torch +from metatensor.torch.atomistic import check_atomistic_model + +from .architectures import import_architecture def check_file_extension( @@ -30,3 +37,51 @@ return str(path_filename) else: return path_filename + + +def is_exported_file(path: str) -> bool: + """Check if a saved model file has been exported to a MetatensorAtomisticModel. + + :param path: model path + :return: :py:obj:`True` if the ``model`` has been exported, :py:obj:`False` + otherwise. + + .. seealso:: + :py:func:`utils.export.is_exported <metatrain.utils.export.is_exported>` to + verify if an already loaded model is exported. + """ + try: + check_atomistic_model(str(path)) + return True + except ValueError: + return False + + +def load_model(architecture_name: str, path: Union[str, Path]) -> Any: + """Load a model from a URL or a local file. + + :param architecture_name: name of the architecture + :param path: local or remote path to a model. 
For supported URL schemes see + :py:mod:`urllib.request` + :raises ValueError: if ``path`` is a YAML option file and not a model + :raises ValueError: if the checkpoint saved in ``path`` does not match the given + ``architecture_name`` + """ + if Path(path).suffix in [".yaml", ".yml"]: + raise ValueError(f"path '{path}' seems to be a YAML option file and no model") + + if urlparse(str(path)).scheme: + path, _ = urlretrieve(str(path)) + + if is_exported_file(str(path)): + return torch.jit.load(str(path)) + else: # model is a checkpoint + architecture = import_architecture(architecture_name) + + try: + return architecture.__model__.load_checkpoint(str(path)) + except Exception as err: + raise ValueError( + f"path '{path}' is not a valid model file for the {architecture_name} " + "architecture" + ) from err diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index 85752c352..0ca1b8948 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -1,8 +1,16 @@ from pathlib import Path import pytest +from torch.jit._script import RecursiveScriptModule -from metatrain.utils.io import check_file_extension +from metatrain.experimental.soap_bpnn.model import SoapBpnn +from metatrain.utils.io import check_file_extension, is_exported_file, load_model + +from . 
import RESOURCES_PATH + + +def is_None(*args, **kwargs) -> None: + return None @pytest.mark.parametrize("filename", ["example.txt", Path("example.txt")]) @@ -21,3 +29,53 @@ assert str(result) == "example.txt" assert isinstance(result, type(filename)) + + +def test_is_exported_file(): + assert is_exported_file(RESOURCES_PATH / "model-32-bit.pt") + assert not is_exported_file(RESOURCES_PATH / "model-32-bit.ckpt") + + +@pytest.mark.parametrize( + "path", + [ + RESOURCES_PATH / "model-32-bit.ckpt", + str(RESOURCES_PATH / "model-32-bit.ckpt"), + f"file:{str(RESOURCES_PATH / 'model-32-bit.ckpt')}", + ], +) +def test_load_model_checkpoint(path): + model = load_model("experimental.soap_bpnn", path) + assert type(model) is SoapBpnn + + +@pytest.mark.parametrize( + "path", + [ + RESOURCES_PATH / "model-32-bit.pt", + str(RESOURCES_PATH / "model-32-bit.pt"), + f"file:{str(RESOURCES_PATH / 'model-32-bit.pt')}", + ], +) +def test_load_model_exported(path): + model = load_model("experimental.soap_bpnn", path) + assert type(model) is RecursiveScriptModule + + +@pytest.mark.parametrize("suffix", [".yml", ".yaml"]) +def test_load_model_yaml(suffix): + match = f"path 'foo{suffix}' seems to be a YAML option file and no model" + with pytest.raises(ValueError, match=match): + load_model("experimental.soap_bpnn", f"foo{suffix}") + + +def test_load_model_unknown_model(): + architecture_name = "experimental.pet" + path = RESOURCES_PATH / "model-32-bit.ckpt" + + match = ( + f"path '{path}' is not a valid model file for the {architecture_name} " + "architecture" + ) + with pytest.raises(ValueError, match=match): + load_model(architecture_name, path) diff --git a/tox.ini b/tox.ini index a38fe4949..f4e783f12 100644 --- a/tox.ini +++ b/tox.ini @@ -59,7 +59,9 @@ deps = pytest-cov pytest-xdist changedir = tests -extras = soap-bpnn # this model is used in the package tests +extras = # architectures used in the package 
tests + soap-bpnn + pet allowlist_externals = bash commands_pre = bash {toxinidir}/tests/resources/generate-outputs.sh commands = @@ -92,7 +94,7 @@ description = Run SOAP-BPNN tests with pytest passenv = * deps = pytest -extras = soap-bpnn +extras = soap-bpnn changedir = src/metatrain/experimental/soap_bpnn/tests/ commands = pytest {posargs}