Skip to content

Commit

Permalink
Make it distributed-proof
Browse files (browse the repository at this point in the history)
  • Loading branch information
frostedoyster committed Oct 18, 2024
1 parent 3c136bf commit c33bcf2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/metatrain/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
export_model,
)
from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model
+ from .utils.distributed.logging import is_main_process
from .utils.logging import get_cli_input, setup_logging


Expand Down Expand Up @@ -81,7 +82,8 @@ def main():
if callable == "train_model":
# define and create `checkpoint_dir` based on current directory, date and time
checkpoint_dir = _datetime_output_path(now=datetime.now())
- os.makedirs(checkpoint_dir, exist_ok=True) # exist_ok=True for distributed
+ if is_main_process():
+ os.makedirs(checkpoint_dir)
args.checkpoint_dir = checkpoint_dir

log_file = checkpoint_dir / "train.log"
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def setup_logging(
stream_handler.setFormatter(formatter)
handlers.append(stream_handler)

- if log_file:
+ if log_file and is_main_process():
log_file = check_file_extension(filename=log_file, extension=".log")
file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8")
file_handler.setFormatter(formatter)
Expand Down

0 comments on commit c33bcf2

Please sign in to comment.