Automatic continuation #363

Merged (6 commits, Oct 25, 2024)
12 changes: 12 additions & 0 deletions docs/src/advanced-concepts/auto-restarting.rst
@@ -0,0 +1,12 @@
Automatic restarting
====================

When a training run has to be restarted several times (for example, when
training an expensive model or when running on an HPC cluster with short time
limits), it is convenient to be able to launch and resume training with the
same command.

In ``metatrain``, this functionality is provided by the ``--continue auto``
(or ``-c auto``) flag of ``mtt train``. With this flag, training automatically
resumes from the most recent checkpoint found in the ``outputs/`` directory of
the current working directory. If no checkpoint is found, training starts from
scratch.
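
For readers of this PR: the lookup described above can be pictured with a short
sketch (illustration only, not part of the diff; it mirrors the
`_process_continue_from` helper added to `src/metatrain/cli/train.py` further
down). It assumes the `outputs/<YYYY-MM-DD>/<HH-MM-SS>/*.ckpt` layout that
`mtt train` writes; `find_latest_checkpoint` is a hypothetical name, not a
metatrain function.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint() -> Optional[str]:
    """Return the newest checkpoint under outputs/, or None if there is none."""
    outputs = Path("outputs")
    if not outputs.exists():
        return None  # no previous run: training starts from scratch
    latest_day = sorted(outputs.iterdir())[-1]     # e.g. outputs/2024-10-25
    latest_run = sorted(latest_day.iterdir())[-1]  # e.g. outputs/2024-10-25/13-45-02
    # checkpoint names may encode the epoch, so sort by creation time, not by name
    checkpoints = sorted(latest_run.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)
    return str(checkpoints[-1]) if checkpoints else None
```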
1 change: 1 addition & 0 deletions docs/src/advanced-concepts/index.rst
@@ -11,3 +11,4 @@ such as output naming, auxiliary outputs, and wrapper models.
output-naming
auxiliary-outputs
multi-gpu
auto-restarting
1 change: 1 addition & 0 deletions pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
"python-hostlist",
"torch",
"vesin",
"numpy < 2.0.0"
]

keywords = ["machine learning", "molecular modeling"]
4 changes: 3 additions & 1 deletion src/metatrain/__main__.py
@@ -16,6 +16,7 @@
export_model,
)
from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
from .utils.distributed.logging import is_main_process
from .utils.logging import get_cli_input, setup_logging


@@ -81,7 +82,8 @@ def main():
if callable == "train_model":
# define and create `checkpoint_dir` based on current directory, date and time
checkpoint_dir = _datetime_output_path(now=datetime.now())
-os.makedirs(checkpoint_dir, exist_ok=True)  # exist_ok=True for distributed
+if is_main_process():
+    os.makedirs(checkpoint_dir)
args.checkpoint_dir = checkpoint_dir

log_file = checkpoint_dir / "train.log"
36 changes: 35 additions & 1 deletion src/metatrain/cli/train.py
@@ -4,6 +4,7 @@
import logging
import os
import random
import time
from pathlib import Path
from typing import Dict, Optional, Union

@@ -74,7 +75,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
"-c",
"--continue",
dest="continue_from",
-type=str,
+type=_process_continue_from,
required=False,
help="File to continue training from.",
)
@@ -98,6 +99,39 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
args.options = OmegaConf.merge(args.options, override_options)


def _process_continue_from(continue_from: str) -> Optional[str]:
# covers the case where `continue_from` is `auto`
if continue_from == "auto":
Contributor: We should inform the user that the training is continued if the
outputs directory is found, or maybe error or warn if continue="auto" and no
directory is found.

Collaborator (author): Ok! It's now there.

# try to find the `outputs` directory; if it doesn't exist
# then we are not continuing from a previous run
if Path("outputs/").exists():
# take the latest year-month-day directory
dir = sorted(Path("outputs/").iterdir())[-1]
# take the latest hour-minute-second directory
dir = sorted(dir.iterdir())[-1]
# take the latest checkpoint. This cannot be done by sorting the
# file names alphabetically, because some checkpoint files are named
# after the epoch number (e.g. `epoch_10.ckpt` would sort before
# `epoch_8.ckpt`). We therefore sort by file creation time.
new_continue_from = str(
sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
)
logger.info(f"Auto-continuing from `{new_continue_from}`")
else:
new_continue_from = None
logger.info(
"Auto-continuation did not find any previous runs, "
"training from scratch"
)
# sleep for a few seconds to allow all processes to catch up. This is
# necessary because the `outputs` directory is created by the main
# process and the other processes might detect it by mistake if they're
# still executing this function
time.sleep(3)

return new_continue_from


def train_model(
options: Union[DictConfig, Dict],
output: str = "model.pt",
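A note on the wiring (editorial commentary, not part of the diff): argparse
applies `type=_process_continue_from` to the raw `--continue` value, so for
`--continue auto` the `continue_from` argument that reaches `train_model` is
already either a concrete checkpoint path or `None`. A minimal sketch of
driving the helper by hand, the same way the new tests do; it assumes the
current working directory may or may not contain an `outputs/` tree:

```python
# Illustration only: resolve the `auto` keyword the same way the CLI would.
from metatrain.cli.train import _process_continue_from

resolved = _process_continue_from("auto")  # sleeps ~3 s so distributed ranks can catch up
if resolved is None:
    print("No previous run found under outputs/; training would start from scratch.")
else:
    print(f"Training would resume from {resolved}")
```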
2 changes: 1 addition & 1 deletion src/metatrain/utils/logging.py
@@ -221,7 +221,7 @@ def setup_logging(
stream_handler.setFormatter(formatter)
handlers.append(stream_handler)

-if log_file:
+if log_file and is_main_process():
log_file = check_file_extension(filename=log_file, extension=".log")
file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
file_handler.setFormatter(formatter)
46 changes: 45 additions & 1 deletion tests/cli/test_train_model.py
@@ -12,7 +12,7 @@
from omegaconf import OmegaConf

from metatrain import RANDOM_SEED
-from metatrain.cli.train import train_model
+from metatrain.cli.train import _process_continue_from, train_model
from metatrain.utils.errors import ArchitectureError

from . import (
@@ -421,6 +421,50 @@ def test_continue(options, monkeypatch, tmp_path):
train_model(options, continue_from=MODEL_PATH_64_BIT)


def test_continue_auto(options, caplog, monkeypatch, tmp_path):
"""Test that continuing with the `auto` keyword results in
a continuation from the most recent checkpoint."""
monkeypatch.chdir(tmp_path)
shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
caplog.set_level(logging.INFO)

# Make up an output directory with some checkpoints
true_checkpoint_dir = Path("outputs/2021-09-02/00-10-05")
true_checkpoint_dir.mkdir(parents=True, exist_ok=True)
# as well as some lower-priority checkpoints
fake_checkpoints_dirs = [
Path("outputs/2021-08-01/00-00-00"),
Path("outputs/2021-09-01/00-00-00"),
Path("outputs/2021-09-02/00-00-00"),
Path("outputs/2021-09-02/00-10-00"),
]
for fake_checkpoint_dir in fake_checkpoints_dirs:
fake_checkpoint_dir.mkdir(parents=True, exist_ok=True)

for i in range(1, 4):
shutil.copy(MODEL_PATH_64_BIT, true_checkpoint_dir / f"model_{i}.ckpt")
for fake_checkpoint_dir in fake_checkpoints_dirs:
shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")

train_model(options, continue_from=_process_continue_from("auto"))

assert "Loading checkpoint from" in caplog.text
assert str(true_checkpoint_dir) in caplog.text
assert "model_3.ckpt" in caplog.text


def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
"""Test that continuing with the `auto` keyword results in
training from scratch if `outputs/` is not present."""
monkeypatch.chdir(tmp_path)
shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
caplog.set_level(logging.INFO)

train_model(options, continue_from=_process_continue_from("auto"))

assert "Loading checkpoint from" not in caplog.text


def test_continue_different_dataset(options, monkeypatch, tmp_path):
"""Test that continuing training from a checkpoint runs without an error raise
with a different dataset than the original."""