Skip to content

Commit

Permalink
allow exporting models on remote locations
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Oct 24, 2024
1 parent 6af1bed commit 47b3798
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 88 deletions.
22 changes: 16 additions & 6 deletions docs/src/getting-started/checkpoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,30 @@ or
mtt train options.yaml -c model.ckpt
Checkpoints can also be turned into exported models using the ``export`` sub-command.
The command requires the *architecture name* and the saved checkpoint *path* as
positional arguments

.. code-block:: bash
mtt export model.ckpt -o model.pt
mtt export experimental.soap_bpnn model.ckpt -o model.pt
or

.. code-block:: bash
mtt export model.ckpt --output model.pt
mtt export experimental.soap_bpnn model.ckpt --output model.pt
To simplify the export and distribution of models, the ``export`` command also supports
loading models from remote locations. To export a remote model, provide a URL instead
of a file path.

.. code-block:: bash
mtt export experimental.soap_bpnn https://my.url.com/model.ckpt --output model.pt
Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have
several dependencies and may become unusable if the corresponding architecture is
updated. In contrast, exported models (``.pt``) act as standalone files.
For long-term usage, you should export your model! Exporting a model is also necessary
if you want to use it in other frameworks, especially in molecular simulations
(see the :ref:`tutorials`).
updated. In contrast, exported models (``.pt``) act as standalone files. For long-term
usage, you should export your model! Exporting a model is also necessary if you want to
use it in other frameworks, especially in molecular simulations (see the
:ref:`tutorials`).
13 changes: 7 additions & 6 deletions examples/programmatic/llpr/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@
#

import torch
from metatensor.torch.atomistic import load_atomistic_model

from metatrain.utils.io import load_model


# %%
#
# Exported models can be loaded using the `load_atomistic_model` function from the
# `metatensor.torch.atomistic` module. The function requires the path to the exported
# model and, for many models, also the path to the respective extensions directory.
# Both are produced during the training process.
# Models can be loaded using the :func:`metatrain.utils.io.load_model` function. For
# already exported models, the function requires the path to the exported model and,
# for many models, also the path to the respective extensions directory. Both are
# produced during the training process.


model = load_atomistic_model("model.pt", extensions_directory="extensions/")
model = load_model("model.pt", extensions_directory="extensions/")

# %%
#
Expand Down
4 changes: 3 additions & 1 deletion src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from ..utils.errors import ArchitectureError
from ..utils.evaluate_model import evaluate_model
from ..utils.io import load_model
from ..utils.logging import MetricLogger
from ..utils.metrics import MAEAccumulator, RMSEAccumulator
from ..utils.neighbor_lists import (
Expand Down Expand Up @@ -95,7 +96,8 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
def _prepare_eval_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for eval_model."""
args.options = OmegaConf.load(args.options)
args.model = metatensor.torch.atomistic.load_atomistic_model(
# models for evaluation are already exported and don't need a name
args.model = load_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)
Expand Down
47 changes: 26 additions & 21 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from pathlib import Path
from typing import Any, Union

import torch
from metatensor.torch.atomistic import MetatensorAtomisticModel, is_atomistic_model

from ..utils.architectures import find_all_architectures, import_architecture
from ..utils.export import is_exported
from ..utils.io import check_file_extension
from ..utils.architectures import find_all_architectures
from ..utils.io import check_file_extension, load_model
from .formatter import CustomHelpFormatter


Expand Down Expand Up @@ -40,7 +39,10 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
parser.add_argument(
"path",
type=str,
help="Saved model which should be exported",
help=(
"Saved model which should be exported. Path can be either a URL or a "
"local file."
),
)
parser.add_argument(
"-o",
Expand All @@ -55,14 +57,14 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:

def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
architecture = import_architecture(architecture_name)

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))
args.model = load_model(
path=args.__dict__.pop("path"),
architecture_name=args.__dict__.pop("architecture_name"),
)


def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> None:
"""Export a trained model to allow it to make predictions.
"""Export a trained model allowing it to make predictions.
This includes predictions within molecular simulation engines. Exported models will
be saved with a ``.pt`` file ending. If ``path`` does not end with this file
Expand All @@ -71,16 +73,19 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
:param model: model to be exported
:param output: path to save the exported model
"""
path = str(check_file_extension(filename=output, extension=".pt"))
path = str(
Path(check_file_extension(filename=output, extension=".pt"))
.absolute()
.resolve()
)
extensions_path = str(Path("extensions/").absolute().resolve())

if is_exported(model):
logger.info(f"The model is already exported. Saving it to `{path}`.")
torch.jit.save(model, path)
else:
extensions_path = "extensions/"
logger.info(
f"Exporting model to '{path}' and extensions to '{extensions_path}'"
if is_atomistic_model(model):
# recreate a valid AtomisticModel for export including extensions
model = MetatensorAtomisticModel(
model.module, model.metadata(), model.capabilities()
)
mts_atomistic_model = model.export()
mts_atomistic_model.save(path, collect_extensions=extensions_path)
logger.info("Model exported successfully")

model = model.export()
model.save(path, collect_extensions=extensions_path)
logger.info(f"Model exported to '{path}' and extensions to '{extensions_path}'")
8 changes: 4 additions & 4 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
import torch
from metatensor.torch.atomistic import load_atomistic_model
from omegaconf import DictConfig, OmegaConf

from .. import PACKAGE_ROOT
Expand All @@ -29,7 +28,7 @@
from ..utils.devices import pick_devices
from ..utils.distributed.logging import is_main_process
from ..utils.errors import ArchitectureError
from ..utils.io import check_file_extension
from ..utils.io import check_file_extension, load_model
from ..utils.jsonschema import validate
from ..utils.omegaconf import BASE_OPTIONS, check_units, expand_dataset_config
from .eval import _eval_targets
Expand Down Expand Up @@ -401,8 +400,9 @@ def train_model(
# EVALUATE FINAL MODEL ####
###########################

mts_atomistic_model = load_atomistic_model(
str(output_checked), extensions_directory=extensions_path
mts_atomistic_model = load_model(
path=output_checked,
extensions_directory=extensions_path,
)
mts_atomistic_model = mts_atomistic_model.to(final_device)

Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/utils/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
ModelEvaluationOptions,
ModelOutput,
System,
is_atomistic_model,
register_autograd_neighbors,
)

from .data import TargetInfoDict
from .export import is_exported
from .output_gradient import compute_gradient


Expand Down Expand Up @@ -221,7 +221,7 @@ def _strain_gradients_to_block(gradients_list):
def _get_outputs(
model: Union[torch.nn.Module, torch.jit._script.RecursiveScriptModule]
):
if is_exported(model):
if is_atomistic_model(model):
return model.capabilities().outputs
else:
return model.outputs
Expand All @@ -237,7 +237,7 @@ def _get_model_outputs(
targets: TargetInfoDict,
check_consistency: bool,
) -> Dict[str, TensorMap]:
if is_exported(model):
if is_atomistic_model(model):
# put together an EvaluationOptions object
options = ModelEvaluationOptions(
length_unit="", # this is only needed for unit conversions in MD engines
Expand Down
21 changes: 2 additions & 19 deletions src/metatrain/utils/export.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import warnings
from typing import Any

import torch
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelMetadata,
is_atomistic_model,
)


Expand All @@ -27,7 +27,7 @@ def export(
:returns: exported model
"""

if is_exported(model):
if is_atomistic_model(model):
return model

if model_capabilities.length_unit == "":
Expand All @@ -46,20 +46,3 @@ def export(
)

return MetatensorAtomisticModel(model.eval(), ModelMetadata(), model_capabilities)


def is_exported(model: Any) -> bool:
    """Tell whether ``model`` is an exported ``MetatensorAtomisticModel``.

    The check is done on the exact type of ``model``. A model which was saved and
    loaded again is of type ``RecursiveScriptModule`` and is therefore accepted as
    well.

    :param model: The model to check
    :return: :py:obj:`True` if the ``model`` has been exported, :py:obj:`False`
        otherwise.
    """
    exported_types = (
        MetatensorAtomisticModel,
        torch.jit._script.RecursiveScriptModule,
    )
    return type(model) in exported_types
86 changes: 85 additions & 1 deletion src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import warnings
from pathlib import Path
from typing import Union
from typing import Any, Optional, Union
from urllib.parse import urlparse
from urllib.request import urlretrieve

from metatensor.torch.atomistic import check_atomistic_model, load_atomistic_model

from .architectures import import_architecture


def check_file_extension(
Expand Down Expand Up @@ -30,3 +36,81 @@ def check_file_extension(
return str(path_filename)
else:
return path_filename


def is_exported_file(path: str) -> bool:
    """Tell whether the file at ``path`` stores an exported MetatensorAtomisticModel.

    The decision is delegated to ``check_atomistic_model``: a :py:class:`ValueError`
    raised by it is interpreted as "not an exported model".

    :param path: model path
    :return: :py:obj:`True` if the ``model`` has been exported, :py:obj:`False`
        otherwise.

    .. seealso::

        :py:func:`metatensor.torch.atomistic.is_atomistic_model` to verify if an
        already loaded model is exported.
    """
    try:
        check_atomistic_model(str(path))
    except ValueError:
        return False
    else:
        return True


def load_model(
    path: Union[str, Path],
    extensions_directory: Optional[Union[str, Path]] = None,
    architecture_name: Optional[str] = None,
) -> Any:
    """Load a model from a URL or a local file.

    The function can load model checkpoints as well as already exported models.

    :param path: local or remote path to a model. For supported URL schemes see
        :py:mod:`urllib.request`
    :param extensions_directory: path to a directory containing all extensions required
        by an *exported* model
    :param architecture_name: name of the architecture required for loading from a
        *checkpoint*.
    :return: the exported model if ``path`` points to an exported model file, otherwise
        the model restored from the checkpoint by the given architecture

    :raises ValueError: if both an ``extensions_directory`` and ``architecture_name``
        are given
    :raises ValueError: if ``path`` is a YAML option file and no model
    :raises ValueError: if no ``architecture_name`` is given for loading a checkpoint
    :raises ValueError: if the checkpoint saved in ``path`` does not match the given
        ``architecture_name``
    """
    if extensions_directory is not None and architecture_name is not None:
        raise ValueError(
            f"Both ``extensions_directory`` ('{str(extensions_directory)}') and "
            f"``architecture_name`` ('{architecture_name}') are given which are "
            "mutually exclusive. An ``extensions_directory`` is only required for "
            "*exported* models while an ``architecture_name`` is only needed for model "
            "*checkpoints*."
        )

    if Path(path).suffix in [".yaml", ".yml"]:
        raise ValueError(f"path '{path}' seems to be a YAML option file and no model")

    # Only treat ``path`` as remote if it carries a real URL scheme. A one-letter
    # "scheme" is a Windows drive letter (e.g. ``C:\model.ckpt``) and must be handled
    # as a local file instead of being handed to ``urlretrieve``.
    if len(urlparse(str(path)).scheme) > 1:
        # download into a temporary local file and continue with that file
        path, _ = urlretrieve(str(path))

    if is_exported_file(str(path)):
        return load_atomistic_model(
            str(path), extensions_directory=extensions_directory
        )
    else:  # model is a checkpoint
        if architecture_name is None:
            raise ValueError(
                f"path '{path}' seems to be a checkpointed model but no "
                "`architecture_name` was given"
            )
        architecture = import_architecture(architecture_name)

        try:
            return architecture.__model__.load_checkpoint(str(path))
        except Exception as err:
            # any failure while restoring is reported as a mismatch between the file
            # and the requested architecture; the original error is kept as cause
            raise ValueError(
                f"path '{path}' is not a valid model file for the {architecture_name} "
                "architecture"
            ) from err
Loading

0 comments on commit 47b3798

Please sign in to comment.