Skip to content

Commit

Permalink
Make it distributed-proof
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 18, 2024
1 parent 3c136bf commit 427d3d8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/metatrain/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
export_model,
)
from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
from .utils.distributed.logging import is_main_process
from .utils.logging import get_cli_input, setup_logging


Expand Down Expand Up @@ -81,7 +82,8 @@ def main():
if callable == "train_model":
# define and create `checkpoint_dir` based on current directory, date and time
checkpoint_dir = _datetime_output_path(now=datetime.now())
os.makedirs(checkpoint_dir, exist_ok=True) # exist_ok=True for distributed
if is_main_process():
os.makedirs(checkpoint_dir)
args.checkpoint_dir = checkpoint_dir

log_file = checkpoint_dir / "train.log"
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def setup_logging(
stream_handler.setFormatter(formatter)
handlers.append(stream_handler)

if log_file:
if log_file and is_main_process():
log_file = check_file_extension(filename=log_file, extension=".log")
file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
file_handler.setFormatter(formatter)
Expand Down
6 changes: 3 additions & 3 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from omegaconf import OmegaConf

from metatrain import RANDOM_SEED
from metatrain.cli.train import train_model
from metatrain.cli.train import _process_continue_from, train_model
from metatrain.utils.errors import ArchitectureError

from . import (
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_continue_auto(options, caplog, monkeypatch, tmp_path):
for fake_checkpoint_dir in fake_checkpoints_dirs:
shutil.copy(MODEL_PATH_64_BIT, fake_checkpoint_dir / f"model_{i}.ckpt")

train_model(options, continue_from="auto")
train_model(options, continue_from=_process_continue_from("auto"))

assert "Loading checkpoint from" in caplog.text
assert str(true_checkpoint_dir) in caplog.text
Expand All @@ -460,7 +460,7 @@ def test_continue_auto_no_outputs(options, caplog, monkeypatch, tmp_path):
shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
caplog.set_level(logging.INFO)

train_model(options, continue_from="auto")
train_model(options, continue_from=_process_continue_from("auto"))

assert "Loading checkpoint from" not in caplog.text

Expand Down

0 comments on commit 427d3d8

Please sign in to comment.