From f79a36d5aa004e4f9fb7d7f575ca297300bc2e58 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 17 Oct 2024 21:37:24 +0200
Subject: [PATCH 1/6] Automatic continuation

---
 src/metatrain/cli/train.py    | 24 +++++++++++++++++++
 tests/cli/test_train_model.py | 44 +++++++++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 5523b6ef..29de6b1a 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -98,6 +98,29 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
     args.options = OmegaConf.merge(args.options, override_options)
 
 
+def _process_continue_from(continue_from: Optional[str]) -> Optional[str]:
+    # covers the case where `continue_from` is `auto`
+    if continue_from == "auto":
+        # try to find the `outputs` directory; if it doesn't exist
+        # then we are not continuing from a previous run
+        if Path("outputs/").exists():
+            # take the latest day directory
+            dir = sorted(Path("outputs/").iterdir())[-1]
+            # take the latest second directory
+            dir = sorted(dir.iterdir())[-1]
+            # take the latest checkpoint. This cannot be done with
+            # `sorted` because some checkpoint files are named with
+            # the epoch number (e.g. `epoch_10.ckpt` would be before
+            # `epoch_8.ckpt`). We therefore sort by file creation time.
+            continue_from = str(
+                sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
+            )
+        else:
+            continue_from = None
+
+    return continue_from
+
+
 def train_model(
     options: Union[DictConfig, Dict],
     output: str = "model.pt",
@@ -334,6 +357,7 @@ def train_model(
 
     logger.info("Setting up model")
     try:
+        continue_from = _process_continue_from(continue_from)
         if continue_from is not None:
             logger.info(f"Loading checkpoint from `{continue_from}`")
             trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index a6abc91e..5fec166a 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -421,6 +421,50 @@ def test_continue(options, monkeypatch, tmp_path):
     train_model(options, continue_from=MODEL_PATH_64_BIT)
 
 
+def test_continue_auto(options, caplog, monkeypatch, tmp_path):
+    """Test that continuing with the `auto` keyword results in
+    a continuation from the most recent checkpoint."""
+    monkeypatch.chdir(tmp_path)
+    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
+    caplog.set_level(logging.INFO)
+
+    # Make up an output directory with some checkpoints
+    true_checkpoint_dir = Path("outputs/2021-09-02/00-10-05")
+    true_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    # as well as some lower-priority checkpoints
+    fake_checkpoints_dirs = [
+        Path("outputs/2021-08-01/00-00-00"),
+        Path("outputs/2021-09-01/00-00-00"),
+        Path("outputs/2021-09-02/00-00-00"),
+        Path("outputs/2021-09-02/00-10-00"),
+    ]
+    for fake_checkpoint_dir in fake_checkpoints_dirs:
+        fake_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    for i in range(1, 4):
+        shutil.copy(MODEL_PATH_64_BIT, true_checkpoint_dir / f"model_{i}.ckpt")
+        for fake_checkpoint_dir in fake_checkpoints_dirs:
+            shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")
+
+    train_model(options, continue_from="auto")
+
+    assert "Loading checkpoint from" in caplog.text
+    assert str(true_checkpoint_dir) in caplog.text
+    assert "model_3.ckpt" in caplog.text
+
+
+def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
+    """Test that continuing with the `auto` keyword results in
+    training from scratch if `outputs/` is not present."""
+    monkeypatch.chdir(tmp_path)
+    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
+    caplog.set_level(logging.INFO)
+
+    train_model(options, continue_from="auto")
+
+    assert "Loading checkpoint from" not in caplog.text
+
+
 def test_continue_different_dataset(options, monkeypatch, tmp_path):
     """Test that continuing training from a checkpoint runs without raising
     an error with a different dataset than the original."""
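
[Note on PATCH 1/6] Plain `sorted()` is chronological for the zero-padded
`YYYY-MM-DD` and `HH-MM-SS` directory names, but not for checkpoint names that
embed an unpadded epoch number, which is why the patch sorts checkpoints by
creation time instead. A minimal, self-contained sketch of the difference (the
file names here are hypothetical, for illustration only):

    from pathlib import Path

    # Lexicographic order puts "epoch_10.ckpt" before "epoch_8.ckpt",
    # so the "largest" name is not the latest epoch:
    names = ["epoch_8.ckpt", "epoch_10.ckpt", "epoch_9.ckpt"]
    assert sorted(names)[-1] == "epoch_9.ckpt"

    # Sorting by creation time, as the patch does, picks the newest file:
    def latest_checkpoint(directory: Path) -> Path:
        checkpoints = sorted(
            directory.glob("*.ckpt"), key=lambda f: f.stat().st_ctime
        )
        return checkpoints[-1]  # IndexError if there are no checkpoints
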
From 3c136bfa693897dd4e0b90cbf54027fd02177bd9 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Fri, 18 Oct 2024 06:24:19 +0200
Subject: [PATCH 2/6] Call processing function before new `outputs/` directory
 is created

---
 src/metatrain/cli/train.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 29de6b1a..66abef34 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -74,7 +74,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
         "-c",
         "--continue",
         dest="continue_from",
-        type=str,
+        type=_process_continue_from,
         required=False,
         help="File to continue training from.",
     )
@@ -98,7 +98,7 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
     args.options = OmegaConf.merge(args.options, override_options)
 
 
-def _process_continue_from(continue_from: Optional[str]) -> Optional[str]:
+def _process_continue_from(continue_from: str) -> Optional[str]:
     # covers the case where `continue_from` is `auto`
     if continue_from == "auto":
         # try to find the `outputs` directory; if it doesn't exist
@@ -112,13 +112,13 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
             # `sorted` because some checkpoint files are named with
             # the epoch number (e.g. `epoch_10.ckpt` would be before
             # `epoch_8.ckpt`). We therefore sort by file creation time.
-            continue_from = str(
+            new_continue_from = str(
                 sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
             )
         else:
-            continue_from = None
+            new_continue_from = None
 
-    return continue_from
+    return new_continue_from
 
 
 def train_model(
@@ -357,7 +357,6 @@ def train_model(
 
     logger.info("Setting up model")
     try:
-        continue_from = _process_continue_from(continue_from)
         if continue_from is not None:
             logger.info(f"Loading checkpoint from `{continue_from}`")
             trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
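
[Note on PATCH 2/6] Moving the call into `type=_process_continue_from` works
because argparse applies the `type` callable while the command line is being
parsed, i.e. before `mtt train` creates the fresh `outputs/` directory for the
current run. A self-contained sketch of this mechanism, using a stand-in
converter rather than the real `_process_continue_from`:

    import argparse

    def _resolve(value: str) -> str:
        # Stand-in converter: argparse calls this during parse_args(),
        # before any training setup has run.
        return value.upper()

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--continue", dest="continue_from", type=_resolve)
    args = parser.parse_args(["--continue", "auto"])
    assert args.continue_from == "AUTO"  # converted at parse time
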
From 427d3d85dc3533200952a644f8864fcb89301d92 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Fri, 18 Oct 2024 06:40:21 +0200
Subject: [PATCH 3/6] Make it distributed-proof

---
 src/metatrain/__main__.py      | 4 +++-
 src/metatrain/utils/logging.py | 2 +-
 tests/cli/test_train_model.py  | 6 +++---
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/metatrain/__main__.py b/src/metatrain/__main__.py
index 40b3eae8..0aaacade 100644
--- a/src/metatrain/__main__.py
+++ b/src/metatrain/__main__.py
@@ -16,6 +16,7 @@
     export_model,
 )
 from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
+from .utils.distributed.logging import is_main_process
 from .utils.logging import get_cli_input, setup_logging
 
 
@@ -81,7 +82,8 @@ def main():
     if callable == "train_model":
         # define and create `checkpoint_dir` based on current directory, date and time
         checkpoint_dir = _datetime_output_path(now=datetime.now())
-        os.makedirs(checkpoint_dir, exist_ok=True)  # exist_ok=True for distributed
+        if is_main_process():
+            os.makedirs(checkpoint_dir)
         args.checkpoint_dir = checkpoint_dir
 
         log_file = checkpoint_dir / "train.log"
diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py
index dbaf0356..cf380154 100644
--- a/src/metatrain/utils/logging.py
+++ b/src/metatrain/utils/logging.py
@@ -221,7 +221,7 @@ def setup_logging(
         stream_handler.setFormatter(formatter)
         handlers.append(stream_handler)
 
-    if log_file:
+    if log_file and is_main_process():
         log_file = check_file_extension(filename=log_file, extension=".log")
         file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
         file_handler.setFormatter(formatter)
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index 5fec166a..5e33c5c7 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -12,7 +12,7 @@
 from omegaconf import OmegaConf
 
 from metatrain import RANDOM_SEED
-from metatrain.cli.train import train_model
+from metatrain.cli.train import _process_continue_from, train_model
 from metatrain.utils.errors import ArchitectureError
 
 from . import (
@@ -446,7 +446,7 @@ def test_continue_auto(options, caplog, monkeypatch, tmp_path):
     for fake_checkpoint_dir in fake_checkpoints_dirs:
         shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")
 
-    train_model(options, continue_from="auto")
+    train_model(options, continue_from=_process_continue_from("auto"))
 
     assert "Loading checkpoint from" in caplog.text
     assert str(true_checkpoint_dir) in caplog.text
@@ -460,7 +460,7 @@ def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
     shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
     caplog.set_level(logging.INFO)
 
-    train_model(options, continue_from="auto")
+    train_model(options, continue_from=_process_continue_from("auto"))
 
     assert "Loading checkpoint from" not in caplog.text
 
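
[Note on PATCH 3/6] Gating `os.makedirs` and the file handler on
`is_main_process()` ensures that, under `torch.distributed`, only rank 0
creates the output directory and writes the log file. The helper is imported
from `metatrain.utils.distributed.logging`; a common shape for such a helper
looks roughly like this (a sketch assuming a rank-0 convention, not
necessarily metatrain's exact implementation):

    import torch.distributed as dist

    def is_main_process() -> bool:
        # With no initialized process group (e.g. a single-process run),
        # the only process is by definition the main one.
        if not dist.is_available() or not dist.is_initialized():
            return True
        return dist.get_rank() == 0
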
From ed0d8e0dc69547a9b26f25fdf5d2628ce5122481 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 24 Oct 2024 14:30:06 +0200
Subject: [PATCH 4/6] Documentation

---
 docs/src/advanced-concepts/auto-restarting.rst | 12 ++++++++++++
 docs/src/advanced-concepts/index.rst           |  1 +
 src/metatrain/cli/train.py                     | 11 +++++++++++
 3 files changed, 24 insertions(+)
 create mode 100644 docs/src/advanced-concepts/auto-restarting.rst

diff --git a/docs/src/advanced-concepts/auto-restarting.rst b/docs/src/advanced-concepts/auto-restarting.rst
new file mode 100644
index 00000000..855dab32
--- /dev/null
+++ b/docs/src/advanced-concepts/auto-restarting.rst
@@ -0,0 +1,12 @@
+Automatic restarting
+====================
+
+When training an expensive model or running on an HPC cluster with short
+time limits, it is often necessary to stop and restart training several
+times. In these cases, it is convenient to restart with the same command.
+
+In ``metatrain``, this functionality is provided via the ``--continue auto``
+(or ``-c auto``) flag of ``mtt train``. This flag automatically restarts
+training from the most recent checkpoint, if one is found in the
+``outputs/`` directory of the current working directory. If no checkpoint
+is found, training starts from scratch.
diff --git a/docs/src/advanced-concepts/index.rst b/docs/src/advanced-concepts/index.rst
index a138e6c3..cfd9f494 100644
--- a/docs/src/advanced-concepts/index.rst
+++ b/docs/src/advanced-concepts/index.rst
@@ -11,3 +11,4 @@ such as output naming, auxiliary outputs, and wrapper models.
    output-naming
    auxiliary-outputs
    multi-gpu
+   auto-restarting
diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 66abef34..6f6a9aed 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import random
+import time
 from pathlib import Path
 from typing import Dict, Optional, Union
 
@@ -115,8 +116,18 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
             new_continue_from = str(
                 sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
             )
+            logger.info(f"Auto-continuing from `{new_continue_from}`")
         else:
             new_continue_from = None
+            logger.info(
+                "Auto-continuation did not find any previous runs, "
+                "training from scratch"
+            )
+            # sleep for a few seconds to allow all processes to catch up. This is
+            # necessary because the `outputs` directory is created by the main
+            # process and the other processes might detect it by mistake if they're
+            # still executing this function
+            time.sleep(3)
 
     return new_continue_from
 
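
[Note on PATCH 4/6] The `time.sleep(3)` is a fixed grace period guarding
against a race: once the main process returns and `mtt train` creates a fresh
`outputs/` directory, a straggler rank still inside `_process_continue_from`
could mistake that new directory for a previous run. Where a process group is
already initialized, an explicit synchronization point would avoid the
hard-coded delay; a sketch of that alternative (not what the patch does, since
at argument-parsing time the process group may not exist yet):

    import torch.distributed as dist

    def _wait_for_all_ranks() -> None:
        # Every rank blocks here until all ranks have finished inspecting
        # `outputs/`, so none can observe a directory the current run created.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
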
From 55d36191ac94d338d1056706c15d7ca421fe4ef8 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 24 Oct 2024 14:44:40 +0200
Subject: [PATCH 5/6] Downgrade numpy

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 07a75e75..d30700fa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "python-hostlist",
     "torch",
     "vesin",
+    "numpy < 2.0.0"
 ]
 
 keywords = ["machine learning", "molecular modeling"]

From 811db9a562a3e26e70e3c9bf589c438b6cf9322d Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 24 Oct 2024 16:18:33 +0200
Subject: [PATCH 6/6] Better comment

---
 src/metatrain/cli/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 6f6a9aed..3325560a 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -105,9 +105,9 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
         # try to find the `outputs` directory; if it doesn't exist
         # then we are not continuing from a previous run
         if Path("outputs/").exists():
-            # take the latest day directory
+            # take the latest year-month-day directory
             dir = sorted(Path("outputs/").iterdir())[-1]
-            # take the latest second directory
+            # take the latest hour-minute-second directory
             dir = sorted(dir.iterdir())[-1]
             # take the latest checkpoint. This cannot be done with
             # `sorted` because some checkpoint files are named with
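
[Note on the series] Taken together, the patches resolve `--continue auto` by
walking a directory layout like the one below and picking the newest
checkpoint. A sketch of the resolution steps, mirroring the patched code (the
paths in the comments are illustrative, not produced by this snippet):

    from pathlib import Path

    # Layout produced by repeated `mtt train` runs (illustrative):
    #   outputs/2024-10-17/21-37-24/model_1.ckpt
    #   outputs/2024-10-17/21-37-24/model_2.ckpt
    #   outputs/2024-10-18/06-24-19/model_1.ckpt   <- newest run
    day = sorted(Path("outputs/").iterdir())[-1]   # latest year-month-day
    run = sorted(day.iterdir())[-1]                # latest hour-minute-second
    ckpt = sorted(run.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
    print(f"continuing from {ckpt}")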