diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 5523b6ef..29de6b1a 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -98,6 +98,29 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
     args.options = OmegaConf.merge(args.options, override_options)
 
 
+def _process_continue_from(continue_from: Optional[str]) -> Optional[str]:
+    # covers the case where `continue_from` is `auto`
+    if continue_from == "auto":
+        # try to find the `outputs` directory; if it doesn't exist
+        # then we are not continuing from a previous run
+        if Path("outputs/").exists():
+            # take the latest day directory
+            latest_dir = sorted(Path("outputs/").iterdir())[-1]
+            # take the latest second directory
+            latest_dir = sorted(latest_dir.iterdir())[-1]
+            # take the latest checkpoint. This cannot be done with
+            # `sorted` on names because some checkpoint files are named
+            # with the epoch number (e.g. `epoch_10.ckpt` would sort
+            # before `epoch_8.ckpt`). We therefore sort by mtime.
+            continue_from = str(
+                sorted(latest_dir.glob("*.ckpt"), key=lambda f: f.stat().st_mtime)[-1]
+            )
+        else:
+            continue_from = None
+
+    return continue_from
+
+
 def train_model(
     options: Union[DictConfig, Dict],
     output: str = "model.pt",
@@ -334,6 +357,7 @@
     logger.info("Setting up model")
 
     try:
+        continue_from = _process_continue_from(continue_from)
         if continue_from is not None:
             logger.info(f"Loading checkpoint from `{continue_from}`")
             trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index a6abc91e..dc3abcd6 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -421,6 +421,50 @@ def test_continue(options, monkeypatch, tmp_path):
     train_model(options, continue_from=MODEL_PATH_64_BIT)
 
 
+def test_continue_auto(options, caplog, monkeypatch, tmp_path):
+    """Test that continuing with the `auto` keyword results in
+    a continuation from the most recent checkpoint."""
+    monkeypatch.chdir(tmp_path)
+    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
+    caplog.set_level(logging.INFO)
+
+    # Make up an output directory with some checkpoints
+    true_checkpoint_dir = Path("outputs/2021-09-02/00-10-05")
+    true_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    # as well as some lower-priority checkpoints
+    fake_checkpoints_dirs = [
+        Path("outputs/2021-08-01/00-00-00"),
+        Path("outputs/2021-09-01/00-00-00"),
+        Path("outputs/2021-09-02/00-00-00"),
+        Path("outputs/2021-09-02/00-10-00"),
+    ]
+    for fake_checkpoint_dir in fake_checkpoints_dirs:
+        fake_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    for i in range(1, 4):
+        shutil.copy(MODEL_PATH_64_BIT, true_checkpoint_dir / f"model_{i}.ckpt")
+        for fake_checkpoint_dir in fake_checkpoints_dirs:
+            shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")
+
+    train_model(options, continue_from="auto")
+
+    assert "Loading checkpoint from" in caplog.text
+    assert str(true_checkpoint_dir) in caplog.text
+    assert "model_3.ckpt" in caplog.text
+
+
+def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
+    """Test that continuing with the `auto` keyword results in
+    training from scratch if `outputs/` is not present."""
+    monkeypatch.chdir(tmp_path)
+    shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
+    caplog.set_level(logging.INFO)
+
+    train_model(options, continue_from="auto")
+
+    assert "Loading checkpoint from" not in caplog.text
+
+
 def test_continue_different_dataset(options, monkeypatch, tmp_path):
     """Test that continuing training from a checkpoint runs without an error
     raise with a different dataset than the original."""