diff --git a/docs/src/conf.py b/docs/src/conf.py index c7e5abacb..e9373df1c 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -9,7 +9,7 @@ # When importing metatensor-torch, this will change the definition of the classes -# to include in the documentation +# to include the documentation os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1" diff --git a/docs/src/dev-docs/new-architecture.rst b/docs/src/dev-docs/new-architecture.rst index 2921a59a1..7d744d294 100644 --- a/docs/src/dev-docs/new-architecture.rst +++ b/docs/src/dev-docs/new-architecture.rst @@ -32,10 +32,10 @@ to these lines checkpoint_dir="path", ) - model.save_checkpoint("final.ckpt") + model.save_checkpoint("model.ckpt") mts_atomistic_model = model.export() - mts_atomistic_model.export("path", collect_extensions="extensions-dir/") + mts_atomistic_model.export("model.pt", collect_extensions="extensions/") In order to follow this, a new architectures has two define two classes diff --git a/src/metatensor/models/__main__.py b/src/metatensor/models/__main__.py index 1c2157ae0..a4bc25c02 100644 --- a/src/metatensor/models/__main__.py +++ b/src/metatensor/models/__main__.py @@ -20,21 +20,6 @@ from .utils.logging import setup_logging -# This import is necessary to avoid errors when loading an -# exported alchemical model, which depends on sphericart-torch. -# TODO: Remove this when https://github.com/lab-cosmo/metatensor/issues/512 -# is ready -try: - import sphericart.torch # noqa: F401 -except ImportError: - pass - -try: - import rascaline.torch # noqa: F401 -except ImportError: - pass - - logger = logging.getLogger(__name__) diff --git a/src/metatensor/models/cli/eval.py b/src/metatensor/models/cli/eval.py index b6f5ad5e8..7633d554e 100644 --- a/src/metatensor/models/cli/eval.py +++ b/src/metatensor/models/cli/eval.py @@ -58,7 +58,7 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None: ) parser.add_argument( "-e", - "--extdir", + "--extensions-dir", type=str, required=False, dest="extensions_directory", diff --git a/src/metatensor/models/cli/export.py b/src/metatensor/models/cli/export.py index 5c452b425..121cc5a2c 100644 --- a/src/metatensor/models/cli/export.py +++ b/src/metatensor/models/cli/export.py @@ -63,4 +63,4 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") -> torch.jit.save(model, path) else: mts_atomistic_model = model.export() - mts_atomistic_model.export(path) + mts_atomistic_model.export(path, collect_extensions="extensions") diff --git a/src/metatensor/models/cli/train.py b/src/metatensor/models/cli/train.py index 57683dfc6..2e55f79e3 100644 --- a/src/metatensor/models/cli/train.py +++ b/src/metatensor/models/cli/train.py @@ -389,7 +389,7 @@ def train_model( raise ArchitectureError(e) mts_atomistic_model = model.export() - mts_atomistic_model.export(str(output_checked)) + mts_atomistic_model.export(str(output_checked), collect_extensions="extensions") ########################### # EVALUATE FINAL MODEL #### diff --git a/src/metatensor/models/utils/evaluate_model.py b/src/metatensor/models/utils/evaluate_model.py index 22e5689b1..aaf5e8fe2 100644 --- a/src/metatensor/models/utils/evaluate_model.py +++ b/src/metatensor/models/utils/evaluate_model.py @@ -22,7 +22,8 @@ "ignore", category=UserWarning, message="neighbor", -) # TODO: this is not filtering out the warning for some reason +) # TODO: this is not filtering out the warning for some reason, therefore: +warnings.filterwarnings("ignore") # ignore all warnings if not in debug mode def evaluate_model(