diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index da80b425..6a730dcb 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -125,6 +125,12 @@ def __init__(self):
             default="tb",
             help="Folder to dump TensorBoard states",
         )
+        self.parser.add_argument(
+            "--metrics.rank_0_only",
+            default=True,
+            action="store_true",
+            help="Whether to save TensorBoard metrics only for rank 0 or for all ranks",
+        )
 
         # model configs
         self.parser.add_argument(
diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py
index 90108976..b9b9cabd 100644
--- a/torchtitan/metrics.py
+++ b/torchtitan/metrics.py
@@ -113,16 +113,21 @@ def close(self):
 
 def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
     dump_dir = config.job.dump_folder
-    save_tb_folder = config.metrics.save_tb_folder
-    # since we don't have run id yet, use current minute as identifier
+    tb_config = config.metrics
+    save_tb_folder = tb_config.save_tb_folder
+    # since we don't have run id, use current minute as the identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
     log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
 
-    enable_tb = config.metrics.enable_tensorboard
+    enable_tb = tb_config.enable_tensorboard
     if enable_tb:
         logger.info(
             f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
         )
+        if tb_config.rank_0_only:
+            enable_tb = torch.distributed.get_rank() == 0
+        else:
+            rank_str = f"rank_{torch.distributed.get_rank()}"
+            log_dir = os.path.join(log_dir, rank_str)
 
-    rank_str = f"rank_{torch.distributed.get_rank()}"
-    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
+    return MetricLogger(log_dir, tag, enable_tb)
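
Behavior sketch (not part of the patch): the snippet below mirrors the new rank-gating logic in `build_metric_logger` so the two modes are easy to compare. The `resolve_tb_logging` helper and the example paths are hypothetical, introduced only for illustration; in the real code the rank comes from `torch.distributed.get_rank()` and the flag from `config.metrics.rank_0_only`.

```python
import os

def resolve_tb_logging(rank: int, log_dir: str, rank_0_only: bool):
    """Hypothetical mirror of the patched logic: returns (log_dir, enable_tb)
    for a given rank, assuming metrics.enable_tensorboard is already True."""
    enable_tb = True
    if rank_0_only:
        # New default: only rank 0 keeps an active writer, directly under log_dir.
        enable_tb = rank == 0
    else:
        # Old behavior, now opt-in: every rank writes to its own subdirectory.
        log_dir = os.path.join(log_dir, f"rank_{rank}")
    return log_dir, enable_tb

# With rank_0_only=True, only rank 0 writes, and no per-rank subfolder is created.
assert resolve_tb_logging(0, "outputs/tb/20240101-0000", True) == ("outputs/tb/20240101-0000", True)
assert resolve_tb_logging(3, "outputs/tb/20240101-0000", True) == ("outputs/tb/20240101-0000", False)
# With rank_0_only=False, every rank writes to its own rank_N subfolder, as before.
assert resolve_tb_logging(3, "outputs/tb/20240101-0000", False) == ("outputs/tb/20240101-0000/rank_3", True)
```

One design note: since the flag is declared with `action="store_true"` and `default=True`, it cannot be switched off from the command line alone; presumably it is meant to be disabled through the job's TOML config (e.g. `rank_0_only = false` under `[metrics]`), which `JobConfig` overlays onto the argparse defaults.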