From c508d9d21988bf1824ce4cd5ac372b02420f2c9a Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 24 Oct 2024 16:35:55 +0200 Subject: [PATCH] Also save models in the `outputs/` directory --- docs/src/getting-started/usage.rst | 12 ++++++++---- pyproject.toml | 1 + src/metatrain/cli/train.py | 12 ++++++++++++ tests/cli/test_train_model.py | 14 ++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/docs/src/getting-started/usage.rst b/docs/src/getting-started/usage.rst index e51519e14..4a0df3e4f 100644 --- a/docs/src/getting-started/usage.rst +++ b/docs/src/getting-started/usage.rst @@ -45,15 +45,19 @@ training using the default hyperparameters of an SOAP BPNN model :language: yaml For each training run a new output directory in the format -``output/YYYY-MM-DD/HH-MM-SS`` based on the current *date* and *time* is created. We use -this output directory to store checkpoints, the ``train.log`` log file as well the -restart ``options_restart.yaml`` file. To start the training create an ``options.yaml`` -file in the current directory and type +``outputs/YYYY-MM-DD/HH-MM-SS`` based on the current *date* and *time* is created. We +use this output directory to store checkpoints, the ``train.log`` log file as well +the restart ``options_restart.yaml`` file. To start the training create an +``options.yaml`` file in the current directory and type .. literalinclude:: ../../../examples/basic_usage/usage.sh :language: bash :lines: 3-8 +After the training has finished, the ``mtt train`` command generates the ``model.ckpt`` +(final checkpoint) and ``model.pt`` (exported model) files in the current directory, as +well as in the ``output/YYYY-MM-DD/HH-MM-SS`` directory. + Evaluation ########## diff --git a/pyproject.toml b/pyproject.toml index fb4ab2b78..7e3f2d703 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "python-hostlist", "torch", "vesin", + "numpy < 2.0.0" ] keywords = ["machine learning", "molecular modeling"] diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 5523b6ef4..5aeb579c2 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -4,6 +4,7 @@ import logging import os import random +import shutil from pathlib import Path from typing import Dict, Optional, Union @@ -397,6 +398,17 @@ def train_model( # the model is first saved and then reloaded 1) for good practice and 2) because # MetatensorAtomisticModel only torchscripts (makes faster) during save() + # Copy the exported model and the checkpoint also to the checkpoint directory + checkpoint_path = Path(checkpoint_dir) + if checkpoint_path != Path("."): + shutil.copy(output_checked, Path(checkpoint_dir) / output_checked) + if Path(f"{Path(output_checked).stem}.ckpt").exists(): + # inside the if because some models don't have a checkpoint (e.g., GAP) + shutil.copy( + f"{Path(output_checked).stem}.ckpt", + Path(checkpoint_dir) / f"{Path(output_checked).stem}.ckpt", + ) + ########################### # EVALUATE FINAL MODEL #### ########################### diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index a6abc91e3..9553db369 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -59,6 +59,20 @@ def test_train(capfd, monkeypatch, tmp_path, output): log_glob = glob.glob("outputs/*/*/train.log") assert len(log_glob) == 1 + model_name = "mymodel" if output == "mymodel.pt" else "model" + + # Test if the model is saved (both .pt and .ckpt) + pt_glob = glob.glob(f"{model_name}.pt") + assert len(pt_glob) == 1 + ckpt_glob = glob.glob(f"{model_name}.ckpt") + assert len(ckpt_glob) == 1 + + # Test if they are also saved to the outputs/ directory + pt_glob = glob.glob(f"outputs/*/*/{model_name}.pt") + assert len(pt_glob) == 1 + ckpt_glob = glob.glob(f"outputs/*/*/{model_name}.ckpt") + assert len(ckpt_glob) == 1 + # Test if extensions are saved extensions_glob = glob.glob("extensions/") assert len(extensions_glob) == 1