Skip to content

Commit

Permalink
Automatic continuation
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 17, 2024
1 parent d99721a commit cb1c66f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,29 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
args.options = OmegaConf.merge(args.options, override_options)


def _process_continue_from(continue_from: Optional[str]) -> Optional[str]:
# covers the case where `continue_from` is `auto`
if continue_from == "auto":
# try to find the `outputs` directory; if it doesn't exist
# then we are not continuing from a previous run
if Path("outputs/").exists():
# take the latest day directory
dir = sorted(Path("outputs/").iterdir())[-1]
# take the latest second directory
dir = sorted(dir.iterdir())[-1]
# take the latest checkpoint. This cannot be done with
# `sorted` because some checkpoint files are named with
# the epoch number (e.g. `epoch_10.ckpt` would be before
# `epoch_8.ckpt`). We therefore sort by file creation time.
continue_from = str(
sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
)
else:
continue_from = None

return continue_from


def train_model(
options: Union[DictConfig, Dict],
output: str = "model.pt",
Expand Down Expand Up @@ -334,6 +357,7 @@ def train_model(

logger.info("Setting up model")
try:
continue_from = _process_continue_from(continue_from)
if continue_from is not None:
logger.info(f"Loading checkpoint from `{continue_from}`")
trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
Expand Down
46 changes: 46 additions & 0 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,52 @@ def test_continue(options, monkeypatch, tmp_path):
train_model(options, continue_from=MODEL_PATH_64_BIT)


def test_continue_auto(options, caplog, monkeypatch, tmp_path):
    """Test that continuing with the `auto` keyword results in
    a continuation from the most recent checkpoint."""
    monkeypatch.chdir(tmp_path)
    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
    caplog.set_level(logging.INFO)

    # Make up an output directory with some checkpoints
    true_checkpoint_dir = Path("outputs/2021-09-02/00-10-05")
    true_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # as well as some lower-priority checkpoints (older day/time directories)
    fake_checkpoints_dirs = [
        Path("outputs/2021-08-01/00-00-00"),
        Path("outputs/2021-09-01/00-00-00"),
        Path("outputs/2021-09-02/00-00-00"),
        Path("outputs/2021-09-02/00-10-00"),
    ]
    for fake_checkpoint_dir in fake_checkpoints_dirs:
        fake_checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # model_3.ckpt is copied last, so it has the latest creation time and
    # must be the checkpoint that `auto` picks up
    for i in range(1, 4):
        shutil.copy(MODEL_PATH_64_BIT, true_checkpoint_dir / f"model_{i}.ckpt")
        for fake_checkpoint_dir in fake_checkpoints_dirs:
            shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")

    train_model(options, continue_from="auto")

    assert "Loading checkpoint from" in caplog.text
    assert str(true_checkpoint_dir) in caplog.text
    assert "model_3.ckpt" in caplog.text


def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
    """Continuing with the `auto` keyword must fall back to training
    from scratch when no `outputs/` directory is present."""
    monkeypatch.chdir(tmp_path)
    caplog.set_level(logging.INFO)
    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")

    # tmp_path contains no outputs/ directory, so no checkpoint can be found
    train_model(options, continue_from="auto")

    # no checkpoint was loaded -> training started from scratch
    assert "Loading checkpoint from" not in caplog.text


def test_continue_different_dataset(options, monkeypatch, tmp_path):
"""Test that continuing training from a checkpoint runs without an error raise
with a different dataset than the original."""
Expand Down

0 comments on commit cb1c66f

Please sign in to comment.