Add support to multiple process groups by syncing across ranks (#151)
Summary:
Add support to multiple process groups by syncing across ranks.

Pull Request resolved: #151

Test Plan: /usr/local/fbcode/platform010/bin/mpirun -np 2 path-to/comm_replay.par --trace-path param_bench/fb/integration_tests/resnet-2gpu --trace-type et

Differential Revision: D60788539

Pulled By: shengfukevin
Sergei-Lebedev authored and facebook-github-bot committed Aug 5, 2024
1 parent fba0236 commit 727f5fd
Showing 3 changed files with 28 additions and 10 deletions.
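The core pattern this commit adds to initialize_groups: every rank publishes the process-group definitions it knows about to a shared key/value store, reads every other rank's definitions, and then enters new_group() for every merged group so that no rank skips the collective call. Below is a rough, hypothetical sketch of that pattern in standalone form; the helper name sync_group_ranks and the surrounding setup are invented for illustration and are not code from this commit.

import json

import torch.distributed as dist


def sync_group_ranks(store: dist.Store, local_groups: dict) -> dict:
    # Assumes dist.init_process_group(...) has already run on every rank.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Publish this rank's {pg_id: [ranks]} view, then wait until all ranks have published.
    store.set(str(rank), json.dumps(local_groups))
    dist.barrier()

    # Read every other rank's view and merge it in; store.get() returns bytes,
    # which json.loads() accepts directly.
    merged = dict(local_groups)
    for other in range(world_size):
        if other == rank:
            continue
        merged.update(json.loads(store.get(str(other))))
    return merged

Every rank then calls new_group() for each entry of the merged dict, whether or not it is a member, which is what the changed loop in pytorch_dist_backend.py below does.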
1 change: 0 additions & 1 deletion et_replay/comm/backend/base_backend.py
@@ -174,7 +174,6 @@ def alloc_ones(
        ipTensor = ipTensor * scaleFactor
        return ipTensor

-    @abstractmethod
    def noop(
        self,
        collectiveArgs: collectiveArgsHolder = None,
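The base_backend.py change simply drops the @abstractmethod decorator from noop, turning it from a method every backend subclass is forced to override into an inherited default. A generic illustration of that difference, using hypothetical classes rather than the repository's own:

from abc import ABC, abstractmethod


class Backend(ABC):
    @abstractmethod
    def all_reduce(self, tensor):
        ...  # every concrete backend must still implement this

    def noop(self, *args, **kwargs):
        pass  # concrete default; subclasses may simply inherit it


class DummyBackend(Backend):
    def all_reduce(self, tensor):
        return tensor


DummyBackend()  # instantiable, because noop no longer requires an override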
32 changes: 24 additions & 8 deletions et_replay/comm/backend/pytorch_dist_backend.py
@@ -3,8 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+import json
import logging
import os
+
from itertools import cycle
from time import sleep
from typing import List, Optional
@@ -1008,7 +1010,8 @@ def get_new_pg(self, group_ranks, backend):
                ranks=group_ranks, backend=backend
            )
        else:
-            return dist.new_group(ranks=group_ranks, backend=backend)
+            pg = dist.new_group(ranks=group_ranks, backend=backend)
+            return pg

    def tensor_list_to_numpy(self, tensorList):
        if isinstance(tensorList, list):
@@ -1070,9 +1073,24 @@ def initialize_backend(
    def initialize_groups(self, backend="gloo"):
        groups = {}
        world_size = self.get_world_size()
+        global_rank = self.get_global_rank()

+        # sync pgs across ranks to fix hang with multiple comm groups
+        # because the new_group() function requires that all processes in the main group enter,
+        # even if they are not going to be members of the group.
+        # Assumption: pg_name is unique and consistent for all ranks
+        sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
+        sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
+        torch.distributed.barrier()
+        group_ranks_sync = self.commsParams.groupRanks.copy()
+        for i in range(self.get_world_size()):
+            if i == global_rank:
+                continue
+            json_data = sync_store.get(str(i))
+            group_ranks_sync.update(json.loads(json_data))
+
        # create additional groups
-        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+        for pg_id, group_ranks in group_ranks_sync.items():
            if (
                len(group_ranks) > world_size
            ):  # this means that --auto-shrink is enabled, only use default pg
@@ -1084,12 +1102,10 @@ def initialize_groups(self, backend="gloo"):
                pg = self.get_default_group()
            else:
                pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
-                global_rank = self.get_global_rank()
-                if global_rank in group_ranks:
-                    logger.info(
-                        f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}"
-                    )
-            groups[pg_id] = pg
+                logger.info(
+                    f"initialized_group: create new group: pg_id = {pg_id} group_ranks = {group_ranks}"
+                )
+            groups[int(pg_id)] = pg

        # if additional groups are created, overwrite the default groups list
        if len(groups):
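The comment in this diff captures the constraint that motivates the sync: torch.distributed.new_group() must be entered by every process in the default group, with the same ranks argument and in the same order, even by ranks that will not be members of the new group; a rank that skips the call can leave the participating ranks hanging. A hypothetical illustration of the required call pattern (assumes init_process_group has already run on every rank):

import torch.distributed as dist

# With, say, world_size == 4, every rank executes both calls below, even though
# rank 3 belongs to neither group; non-members simply get a placeholder handle back.
pg_a = dist.new_group(ranks=[0, 1], backend="gloo")
pg_b = dist.new_group(ranks=[0, 2], backend="gloo")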
5 changes: 4 additions & 1 deletion et_replay/tools/comm_replay.py
@@ -31,7 +31,8 @@
from et_replay.comm.param_profile import paramProfile, paramTimer

try:
-    from trainer_iteration_wrapper import setTrainingIteration  # @manual
+    # pyre-ignore[21]:
+    from trainer_iteration_wrapper import setTrainingIteration
except ImportError:
    pass
@@ -89,6 +90,8 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> None:
        json.dump(commsTracePerf, write_file, indent=2)


+# pyre-ignore[13]: lint complained that self.backendFuncs is never initialized;
+# it is initialized in initBackEnd
class commsTraceReplayBench(paramCommsBench):
    """
    A class to replay and benchmark generated traces for collective communications.
