add an init-only mode to benchmark initialization alone
Summary:
as $title
Add an init-only mode to the param benchmark to measure NCCL initialization time alone. This requires the NCCL communicator to be initialized in eager mode (not lazy mode, where initialization is triggered by the first collective).

Reviewed By: cenzhaometa

Differential Revision: D56767297

fbshipit-source-id: c4d70540d3f9dc007e2b1a51b08a477da2ca8938
shengbao-zheng authored and facebook-github-bot committed May 1, 2024
1 parent 425e08a commit c83ce84
Showing 3 changed files with 40 additions and 8 deletions.
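
Background on the eager-mode requirement mentioned in the summary: the NCCL process group normally creates its communicators lazily on the first collective, so measuring initialization by itself requires forcing communicator creation up front. Recent PyTorch versions (2.3+) expose this through the device_id argument of init_process_group, which is what the pytorch_dist_backend change below relies on. A minimal standalone sketch of timing an eager init, not code from this repo (assumes a torchrun-style launch that sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK):

    # Standalone illustration of eager NCCL init; not part of param_bench.
    import os
    import time

    import torch
    import torch.distributed as dist


    def timed_eager_init() -> float:
        """Initialize the default NCCL process group eagerly; return seconds elapsed."""
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)

        start = time.monotonic()
        # Passing device_id asks PyTorch to create the NCCL communicator now,
        # rather than lazily on the first collective.
        dist.init_process_group(
            backend="nccl",
            device_id=torch.device(f"cuda:{local_rank}"),
        )
        return time.monotonic() - start


    if __name__ == "__main__":
        elapsed = timed_eager_init()
        print(f"rank {dist.get_rank()}: init took {elapsed:.3f}s")
        dist.destroy_process_group()
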
20 changes: 15 additions & 5 deletions train/comms/pt/comms.py
@@ -10,6 +10,7 @@
 import argparse
 import logging
 import time
+from inspect import signature
 
 import numpy as np
 
@@ -1788,11 +1789,20 @@ def initBackend(
             comms_utils.gracefulExit()
 
         self.backendFuncs = backendObj
-        self.backendFuncs.initialize_backend(
-            bootstrap_info.master_ip,
-            bootstrap_info.master_port,
-            backend=commsParams.backend,
-        )
+        sig = signature(self.backendFuncs.initialize_backend)
+        if "eager_mode" in sig.parameters:
+            self.backendFuncs.initialize_backend(
+                bootstrap_info.master_ip,
+                bootstrap_info.master_port,
+                backend=commsParams.backend,
+                eager_mode=commsParams.init_only,
+            )
+        else:
+            self.backendFuncs.initialize_backend(
+                bootstrap_info.master_ip,
+                bootstrap_info.master_port,
+                backend=commsParams.backend,
+            )
         self.backendFuncs.sayHello()  # Informs us where each process is running.
 
     def runBench(self, commsParams):
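
Note on the signature check above: comms.py forwards eager_mode only when the backend's initialize_backend declares that parameter, so backend implementations that predate this change keep working unchanged. A self-contained sketch of the same pattern (legacy_init and new_init are illustrative names, not repo code):

    from inspect import signature


    def legacy_init(master_ip, master_port, backend="gloo"):
        print(f"legacy init via {backend}")


    def new_init(master_ip, master_port, backend="gloo", eager_mode=False):
        print(f"new init via {backend}, eager_mode={eager_mode}")


    def call_init(init_fn, master_ip, master_port, backend, eager_mode):
        # Forward eager_mode only if the target function declares it.
        kwargs = {"backend": backend}
        if "eager_mode" in signature(init_fn).parameters:
            kwargs["eager_mode"] = eager_mode
        init_fn(master_ip, master_port, **kwargs)


    call_init(legacy_init, "127.0.0.1", "29500", "nccl", True)  # eager_mode dropped
    call_init(new_init, "127.0.0.1", "29500", "nccl", True)     # eager_mode forwarded
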
7 changes: 7 additions & 0 deletions train/comms/pt/comms_utils.py
@@ -799,6 +799,7 @@ def __init__(self, args: Namespace) -> None:
         self.enable_profiler = args.enable_profiler
         self.use_perf_logger = args.use_perf_logger
         self.ibv_devices = args.ibv_devices
+        self.init_only = args.init_only
 
 
 class commsDlrmParamsHolder(commsParamsHolderBase):
@@ -1635,6 +1636,12 @@ def readArgs(self, parser: ArgumentParser) -> None:
             default="",
             help="list of ib devices to use for distributed communication",
         )  # experimental feature
+        parser.add_argument(
+            "--init-only",
+            action="store_true",
+            default=False,
+            help="Toggle to skip running collectives and only do initialization",
+        )
         pass
 
     @abstractmethod
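
For reference, --init-only is a plain store_true flag, so argparse exposes it as args.init_only (dashes become underscores) and it defaults to False. A tiny sketch of that behavior with an illustrative parser, not the benchmark's real one:

    import argparse

    parser = argparse.ArgumentParser(description="toy parser mirroring the new flag")
    parser.add_argument(
        "--init-only",
        action="store_true",
        default=False,
        help="Toggle to skip running collectives and only do initialization",
    )

    # argparse converts the dashed flag name to an underscore attribute.
    print(parser.parse_args([]).init_only)               # False
    print(parser.parse_args(["--init-only"]).init_only)  # True
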
21 changes: 18 additions & 3 deletions train/comms/pt/pytorch_dist_backend.py
@@ -6,6 +6,7 @@
 import logging
 import os
 from itertools import cycle
+from time import sleep
 from typing import List, Optional
 
 import numpy as np
@@ -1033,10 +1034,16 @@ def initialize_tcpstore(self, master_ip, master_port):
         global_rank = self.bootstrap_info.global_rank
         world_size = self.bootstrap_info.world_size
         self.tcp_store = dist.TCPStore(
-            master_ip, int(master_port), world_size, is_master=(global_rank == 0)
+            master_ip,
+            int(master_port),
+            world_size,
+            is_master=(global_rank == 0),
+            use_libuv=True,
         )
 
-    def initialize_backend(self, master_ip, master_port, backend="gloo"):
+    def initialize_backend(
+        self, master_ip, master_port, backend="gloo", eager_mode=False
+    ):
         # Set CUDA device before initializing backend
         # Required for backends that don't do lazy initialization, e.g. UCC
         self.set_device(self.bootstrap_info.local_rank, self.bootstrap_info.global_rank)
@@ -1067,6 +1074,11 @@ def initialize_backend(self, master_ip, master_port, backend="gloo"):
             world_size=world_size,
             store=self.tcp_store if self.commsParams.init_method is None else None,
             init_method=self.commsParams.init_method,
+            device_id=(
+                torch.device(f"cuda:{self.bootstrap_info.local_rank}")
+                if eager_mode
+                else None
+            ),
         )
 
         # default 1 group, maybe overwritten by user created groups via initialize_groups
@@ -1109,7 +1121,10 @@ def initialize_groups(self, backend="gloo"):
 
     def benchmark_comms(self, benchTime, commsParams):
         index = 0  # used in TPU, where it is not initialized!
-        benchTime(index, commsParams, self)
+        if commsParams.init_only:
+            sleep(10)
+        else:
+            benchTime(index, commsParams, self)
         return
 
     def __del__(self):
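
With init_only set, benchmark_comms above skips benchTime entirely and just sleeps for 10 seconds before returning, so the run consists essentially of backend initialization. If one wanted to log how long that initialization takes, a generic timing wrapper could look like the sketch below (log_duration and init_fn are illustrative names, not repo APIs):

    # Generic timing helper; illustrative only, not part of param_bench.
    import time
    from contextlib import contextmanager


    @contextmanager
    def log_duration(label: str):
        """Print how long the wrapped block took."""
        start = time.monotonic()
        try:
            yield
        finally:
            print(f"{label}: {time.monotonic() - start:.3f}s")


    def init_fn():
        time.sleep(0.5)  # stands in for eager NCCL initialization work


    with log_duration("backend init"):
        init_fn()
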
