Automatic continuation #363

Merged: 6 commits, Oct 25, 2024. Changes from 3 commits.
4 changes: 3 additions & 1 deletion src/metatrain/__main__.py
@@ -16,6 +16,7 @@
    export_model,
)
from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
+from .utils.distributed.logging import is_main_process
from .utils.logging import get_cli_input, setup_logging


@@ -81,7 +82,8 @@ def main():
    if callable == "train_model":
        # define and create `checkpoint_dir` based on current directory, date and time
        checkpoint_dir = _datetime_output_path(now=datetime.now())
-        os.makedirs(checkpoint_dir, exist_ok=True)  # exist_ok=True for distributed
+        if is_main_process():
+            os.makedirs(checkpoint_dir)
        args.checkpoint_dir = checkpoint_dir

        log_file = checkpoint_dir / "train.log"
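In a distributed run every process executes `main()`, which is why the removed line needed `exist_ok=True`: all ranks raced to create the same directory. After this change only the main process creates `checkpoint_dir`. As a rough sketch (an illustration of the usual pattern, not necessarily the actual body of `metatrain.utils.distributed.logging.is_main_process`), such a helper is typically a rank-0 check built on `torch.distributed`:

import torch.distributed as dist


def is_main_process() -> bool:
    # Sketch only: assumes torch.distributed is the mechanism used for distribution.
    # Single-process runs have no initialized process group, so they count as "main".
    if not (dist.is_available() and dist.is_initialized()):
        return True
    # In a distributed run, rank 0 is conventionally treated as the main process.
    return dist.get_rank() == 0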
25 changes: 24 additions & 1 deletion src/metatrain/cli/train.py
@@ -74,7 +74,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
        "-c",
        "--continue",
        dest="continue_from",
-        type=str,
+        type=_process_continue_from,
        required=False,
        help="File to continue training from.",
    )
@@ -98,6 +98,29 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
    args.options = OmegaConf.merge(args.options, override_options)


def _process_continue_from(continue_from: str) -> Optional[str]:
    # covers the case where `continue_from` is `auto`
    if continue_from == "auto":
Review comment (Contributor): We should inform the user that the training is continued if the outputs directory is found, or maybe error or warn if continue="auto" and no directory is found.

Reply (Collaborator, author): Ok! It's now there.

        # try to find the `outputs` directory; if it doesn't exist
        # then we are not continuing from a previous run
        if Path("outputs/").exists():
            # take the latest day directory
            dir = sorted(Path("outputs/").iterdir())[-1]
            # take the latest second directory
Review comment (Contributor), suggested change: replace "# take the latest second directory" with "# take the latest time directory".

Reply (frostedoyster, author, Oct 24, 2024): Also the first directory (the "day" directory) is technically "time"; I'll make it clearer.

            dir = sorted(dir.iterdir())[-1]
            # take the latest checkpoint. This cannot be done with
            # `sorted` because some checkpoint files are named with
            # the epoch number (e.g. `epoch_10.ckpt` would be before
            # `epoch_8.ckpt`). We therefore sort by file creation time.
            new_continue_from = str(
                sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
            )
        else:
            new_continue_from = None

    return new_continue_from


def train_model(
    options: Union[DictConfig, Dict],
    output: str = "model.pt",
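The comment inside `_process_continue_from` deserves a concrete illustration: sorting checkpoint names lexicographically stops matching epoch order once epoch numbers reach two digits, so the function sorts by file creation time instead. A minimal sketch of both behaviours (the `outputs/2024-10-25/12-00-00` path is a made-up example, not something this PR creates):

from pathlib import Path

# Lexicographic order is not epoch order: "epoch_10.ckpt" sorts before "epoch_8.ckpt".
print(sorted(["epoch_8.ckpt", "epoch_9.ckpt", "epoch_10.ckpt"]))
# -> ['epoch_10.ckpt', 'epoch_8.ckpt', 'epoch_9.ckpt']

# Sorting by creation time, as `_process_continue_from` does, picks the newest file:
run_dir = Path("outputs/2024-10-25/12-00-00")  # hypothetical run directory
checkpoints = sorted(run_dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)
latest = str(checkpoints[-1]) if checkpoints else None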
2 changes: 1 addition & 1 deletion src/metatrain/utils/logging.py
@@ -221,7 +221,7 @@ def setup_logging(
    stream_handler.setFormatter(formatter)
    handlers.append(stream_handler)

-    if log_file:
+    if log_file and is_main_process():
        log_file = check_file_extension(filename=log_file, extension=".log")
        file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
        file_handler.setFormatter(formatter)
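The extra `is_main_process()` guard matters because, under distributed training, every rank calls `setup_logging`; without it each rank would open the same `train.log` and the messages of all ranks would end up interleaved in the file. A minimal sketch of the guarded pattern (`build_handlers` is a hypothetical helper name used only for this illustration):

import logging
from typing import List, Optional

from metatrain.utils.distributed.logging import is_main_process


def build_handlers(log_file: Optional[str]) -> List[logging.Handler]:
    # Every rank gets a console handler...
    handlers: List[logging.Handler] = [logging.StreamHandler()]
    # ...but only the main rank attaches a handler for the shared log file.
    if log_file and is_main_process():
        handlers.append(logging.FileHandler(log_file, encoding="utf-8"))
    return handlers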
46 changes: 45 additions & 1 deletion tests/cli/test_train_model.py
@@ -12,7 +12,7 @@
from omegaconf import OmegaConf

from metatrain import RANDOM_SEED
-from metatrain.cli.train import train_model
+from metatrain.cli.train import _process_continue_from, train_model
from metatrain.utils.errors import ArchitectureError

from . import (
@@ -421,6 +421,50 @@ def test_continue(options, monkeypatch, tmp_path):
    train_model(options, continue_from=MODEL_PATH_64_BIT)


def test_continue_auto(options, caplog, monkeypatch, tmp_path):
    """Test that continuing with the `auto` keyword results in
    a continuation from the most recent checkpoint."""
    monkeypatch.chdir(tmp_path)
    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
    caplog.set_level(logging.INFO)

    # Make up an output directory with some checkpoints
    true_checkpoint_dir = Path("outputs/2021-09-02/00-10-05")
    true_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # as well as some lower-priority checkpoints
    fake_checkpoints_dirs = [
        Path("outputs/2021-08-01/00-00-00"),
        Path("outputs/2021-09-01/00-00-00"),
        Path("outputs/2021-09-02/00-00-00"),
        Path("outputs/2021-09-02/00-10-00"),
    ]
    for fake_checkpoint_dir in fake_checkpoints_dirs:
        fake_checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for i in range(1, 4):
        shutil.copy(MODEL_PATH_64_BIT, true_checkpoint_dir / f"model_{i}.ckpt")
        for fake_checkpoint_dir in fake_checkpoints_dirs:
            shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")

    train_model(options, continue_from=_process_continue_from("auto"))

    assert "Loading checkpoint from" in caplog.text
    assert str(true_checkpoint_dir) in caplog.text
    assert "model_3.ckpt" in caplog.text


def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
    """Test that continuing with the `auto` keyword results in
    training from scratch if `outputs/` is not present."""
    monkeypatch.chdir(tmp_path)
    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
    caplog.set_level(logging.INFO)

    train_model(options, continue_from=_process_continue_from("auto"))

    assert "Loading checkpoint from" not in caplog.text


def test_continue_different_dataset(options, monkeypatch, tmp_path):
    """Test that continuing training from a checkpoint runs without an error raise
    with a different dataset than the original."""