Skip to content

Commit

Permalink
Call processing function before new outputs/ directory is created
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 18, 2024
1 parent f79a36d commit 3c136bf
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
"-c",
"--continue",
dest="continue_from",
type=str,
type=_process_continue_from,
required=False,
help="File to continue training from.",
)
Expand All @@ -98,7 +98,7 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
args.options = OmegaConf.merge(args.options, override_options)


def _process_continue_from(continue_from: Optional[str]) -> Optional[str]:
def _process_continue_from(continue_from: str) -> Optional[str]:
# covers the case where `continue_from` is `auto`
if continue_from == "auto":
# try to find the `outputs` directory; if it doesn't exist
Expand All @@ -112,13 +112,13 @@ def _process_continue_from(continue_from: Optional[str]) -> Optional[str]:
# `sorted` because some checkpoint files are named with
# the epoch number (e.g. `epoch_10.ckpt` would be before
# `epoch_8.ckpt`). We therefore sort by file creation time.
continue_from = str(
new_continue_from = str(
sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
)
else:
continue_from = None
new_continue_from = None

return continue_from
return new_continue_from


def train_model(
Expand Down Expand Up @@ -357,7 +357,6 @@ def train_model(

logger.info("Setting up model")
try:
continue_from = _process_continue_from(continue_from)
if continue_from is not None:
logger.info(f"Loading checkpoint from `{continue_from}`")
trainer = Trainer.load_checkpoint(continue_from, hypers["training"])
Expand Down

0 comments on commit 3c136bf

Please sign in to comment.