Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Nov 5, 2024
1 parent 3d73414 commit 9109dfb
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 45 deletions.
20 changes: 2 additions & 18 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,8 @@ This is the API for the ``utils`` module of ``metatrain``.

.. toctree::
:maxdepth: 1
:glob:

additive/index
data/index
architectures
devices
dtype
errors
evaluate_model
external_naming
export
io
jsonschema
logging
loss
metrics
neighbor_lists
omegaconf
output_gradient
per_atom
transfer
units
./*
9 changes: 5 additions & 4 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
extensions_path = str(Path("extensions/").absolute().resolve())

if is_atomistic_model(model):
# recreate a valid AtomisticModel for export including extensions
model = MetatensorAtomisticModel(
# recreate a valid AtomisticModel including extensions for export
atomistic_model = MetatensorAtomisticModel(
model.module, model.metadata(), model.capabilities()
)
else:
atomistic_model = model.export()

model = model.export()
model.save(path, collect_extensions=extensions_path)
atomistic_model.save(path, collect_extensions=extensions_path)
logger.info(f"Model exported to '{path}' and extensions to '{extensions_path}'")
3 changes: 1 addition & 2 deletions src/metatrain/experimental/pet/tests/test_exported.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from metatrain.experimental.pet import PET as WrappedPET
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.export import export
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
Expand Down Expand Up @@ -54,7 +53,7 @@ def test_to(device):
supported_devices=["cpu", "cuda"],
)

exported = export(model, capabilities)
exported = model.export()
exported.to(device=device, dtype=dtype)

system = System(
Expand Down
14 changes: 1 addition & 13 deletions tests/utils/test_evaluate_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import pytest
import torch
from metatensor.torch.atomistic import ModelCapabilities

from metatrain.experimental.soap_bpnn import __model__
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.evaluate_model import evaluate_model
from metatrain.utils.export import export
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
Expand Down Expand Up @@ -37,17 +35,7 @@ def test_evaluate_model(training, exported):
model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info)

if exported:

capabilities = ModelCapabilities(
length_unit=model.dataset_info.length_unit,
outputs=model.outputs,
atomic_types=list(model.dataset_info.atomic_types),
supported_devices=model.__supported_devices__,
interaction_range=model.hypers["soap"]["cutoff"],
dtype="float32",
)

model = export(model, capabilities)
model = model.export()
requested_neighbor_lists = get_requested_neighbor_lists(model)
systems = [
get_system_with_neighbor_lists(system, requested_neighbor_lists)
Expand Down
11 changes: 5 additions & 6 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,10 @@ def test_load_model_unknown_model():

def test_extensions_directory_and_architecture_name():
match = (
r"Both ``extensions_directory`` \('.'\) and ``architecture_name`` ('foo') are "
r"given which are mutually exclusive. An ``extensions_directory`` is only "
r"required for *exported* models while an ``architecture_name`` is only needed "
r"for model *checkpoints*."
r"Both ``extensions_directory`` \('.'\) and ``architecture_name`` \('foo'\) "
r"are given which are mutually exclusive. An ``extensions_directory`` is only "
r"required for \*exported\* models while an ``architecture_name`` is only "
r"needed for model \*checkpoints\*."
)
with pytest.raises(ValueError, match=match):
load_model("model.pt", extensions_directory=".", architecture_name="foo"
)
load_model("model.pt", extensions_directory=".", architecture_name="foo")
4 changes: 2 additions & 2 deletions tests/utils/test_llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_llpr(tmpdir):

def test_llpr_covariance_as_pseudo_hessian(tmpdir):

model = load_atomistic_model(
model = load_model(
str(RESOURCES_PATH / "model-64-bit.pt"),
extensions_directory=str(RESOURCES_PATH / "extensions/"),
)
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir):
file=str(tmpdir / "llpr_model.pt"),
collect_extensions=str(tmpdir / "extensions"),
)
llpr_model = load_atomistic_model(
llpr_model = load_model(
str(tmpdir / "llpr_model.pt"), extensions_directory=str(tmpdir / "extensions")
)

Expand Down

0 comments on commit 9109dfb

Please sign in to comment.