From c33bcf273fbea129c5493bea1315a79edacabc6e Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 18 Oct 2024 06:40:21 +0200 Subject: [PATCH] Make it distributed-proof --- src/metatrain/__main__.py | 4 +++- src/metatrain/utils/logging.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/metatrain/__main__.py b/src/metatrain/__main__.py index 40b3eae8e..0aaacade0 100644 --- a/src/metatrain/__main__.py +++ b/src/metatrain/__main__.py @@ -16,6 +16,7 @@ export_model, ) from .cli.train import _add_train_model_parser, _prepare_train_model_args, train_model +from .utils.distributed.logging import is_main_process from .utils.logging import get_cli_input, setup_logging @@ -81,7 +82,8 @@ def main(): if callable == "train_model": # define and create `checkpoint_dir` based on current directory, date and time checkpoint_dir = _datetime_output_path(now=datetime.now()) - os.makedirs(checkpoint_dir, exist_ok=True) # exist_ok=True for distributed + if is_main_process(): + os.makedirs(checkpoint_dir) args.checkpoint_dir = checkpoint_dir log_file = checkpoint_dir / "train.log" diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py index dbaf0356b..cf3801542 100644 --- a/src/metatrain/utils/logging.py +++ b/src/metatrain/utils/logging.py @@ -221,7 +221,7 @@ def setup_logging( stream_handler.setFormatter(formatter) handlers.append(stream_handler) - if log_file: + if log_file and is_main_process(): log_file = check_file_extension(filename=log_file, extension=".log") file_handler = logging.FileHandler(filename=str(log_file), encoding="utf-8") file_handler.setFormatter(formatter)