Add support for multiple process groups by syncing across ranks (#151)
Summary:
Add support for multiple process groups by syncing process-group membership across ranks, so every rank creates the same groups in the same order.

Pull Request resolved: #151

Test Plan: /usr/local/fbcode/platform010/bin/mpirun -np 2 path-to/comm_replay.par --trace-path param_bench/fb/integration_tests/resnet-2gpu --trace-type et

Reviewed By: briancoutinho

Differential Revision: D60788539

Pulled By: shengfukevin

fbshipit-source-id: a33932be1bc81e4865d027b071fa855a347d857b
Sergei-Lebedev authored and facebook-github-bot committed Aug 7, 2024
1 parent fba0236 commit c466b60
Showing 3 changed files with 33 additions and 9 deletions.
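
Why the fix syncs across ranks: torch.distributed.new_group() is a collective call, so every rank in the default group must invoke it for every group being created, in the same order, even for groups that rank will not join. When each rank only creates the groups it belongs to, ranks block inside new_group() waiting for the others and the replay hangs. Below is a minimal, hypothetical sketch of the required call pattern (the two-rank launch, the group layout, and the file name are assumptions for illustration, not part of this commit):

# sketch_new_group.py -- hypothetical illustration, run with:
#   torchrun --nproc_per_node=2 sketch_new_group.py
import torch.distributed as dist

def create_groups(all_group_ranks):
    """Every rank calls new_group() for every group, in the same order,
    even for groups it is not a member of -- otherwise ranks deadlock."""
    groups = {}
    for pg_id in sorted(all_group_ranks):  # identical iteration order on all ranks
        groups[pg_id] = dist.new_group(ranks=all_group_ranks[pg_id], backend="gloo")
    return groups

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    # Both ranks pass the *same* mapping, including groups they do not belong to.
    create_groups({0: [0], 1: [1], 2: [0, 1]})
    dist.barrier()
    dist.destroy_process_group()
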
et_replay/comm/backend/base_backend.py (1 change: 0 additions & 1 deletion)
@@ -174,7 +174,6 @@ def alloc_ones(
         ipTensor = ipTensor * scaleFactor
         return ipTensor
 
-    @abstractmethod
     def noop(
         self,
         collectiveArgs: collectiveArgsHolder = None,
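
A note on the base_backend.py change above: removing @abstractmethod means backend subclasses are no longer forced to override noop(); the body already defined in the base class becomes a shared default. A tiny, hypothetical sketch of the effect (the class names and the all_reduce stub are illustrative, not the real API):

# hypothetical sketch: effect of dropping @abstractmethod from noop()
from abc import ABC, abstractmethod

class BaseBackendSketch(ABC):
    @abstractmethod
    def all_reduce(self, collectiveArgs):  # still abstract: must be overridden
        ...

    def noop(self, collectiveArgs=None):
        """Concrete default; subclasses inherit it instead of reimplementing it."""
        pass

class MinimalBackend(BaseBackendSketch):
    def all_reduce(self, collectiveArgs):
        return collectiveArgs

MinimalBackend()  # instantiable even though noop() is not overridden
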
et_replay/comm/backend/pytorch_dist_backend.py (36 changes: 29 additions & 7 deletions)
@@ -3,8 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import logging
 import os
+
 from itertools import cycle
 from time import sleep
 from typing import List, Optional
@@ -1008,7 +1010,8 @@ def get_new_pg(self, group_ranks, backend):
                 ranks=group_ranks, backend=backend
             )
         else:
-            return dist.new_group(ranks=group_ranks, backend=backend)
+            pg = dist.new_group(ranks=group_ranks, backend=backend)
+            return pg
 
     def tensor_list_to_numpy(self, tensorList):
         if isinstance(tensorList, list):
@@ -1070,9 +1073,30 @@ def initialize_backend(
     def initialize_groups(self, backend="gloo"):
         groups = {}
         world_size = self.get_world_size()
+        global_rank = self.get_global_rank()
 
+        # sync pgs across ranks to fix a hang with multiple comm groups,
+        # because new_group() requires that all processes in the main group enter,
+        # even if they are not going to be members of the group.
+        # Assumption: pg_name is unique and consistent for all ranks
+        sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
+        sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
+        torch.distributed.barrier()
+        group_ranks_sync = self.commsParams.groupRanks.copy()
+        for i in range(self.get_world_size()):
+            if i == global_rank:
+                continue
+            json_data = sync_store.get(str(i))
+
+            # convert pg_id in json_data to int
+            pg_id2group_ranks = {
+                int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
+            }
+
+            group_ranks_sync.update(pg_id2group_ranks)
+
         # create additional groups
-        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+        for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
             if (
                 len(group_ranks) > world_size
             ):  # this means that --auto-shrink is enabled, only use default pg
@@ -1084,11 +1108,9 @@ def initialize_groups(self, backend="gloo"):
                 pg = self.get_default_group()
             else:
                 pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
-                global_rank = self.get_global_rank()
-                if global_rank in group_ranks:
-                    logger.info(
-                        f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}"
-                    )
+                logger.info(
+                    f"initialized_group: create new group, pg_id = {pg_id}, group_ranks = {group_ranks}"
+                )
             groups[pg_id] = pg
 
         # if additional groups are created, overwrite the default groups list
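
The heart of the fix above is the store-based exchange: each rank publishes its own pg_id -> group_ranks mapping into a PrefixStore, waits at a barrier, then reads and merges every other rank's mapping, so all ranks create the identical set of groups in the same sorted order. A standalone, hypothetical sketch of that pattern (the function name, key prefix, and store argument are assumptions for illustration; in the diff the store is self.tcp_store):

# hypothetical standalone sketch of the rank-sync pattern used above
import json
import torch.distributed as dist

def sync_group_ranks(tcp_store, global_rank, world_size, local_group_ranks):
    """Publish this rank's {pg_id: ranks} map, then merge every rank's map so
    all ranks see the same, ordered view of the groups to create."""
    sync_store = dist.PrefixStore("pg_sync_demo", tcp_store)
    sync_store.set(str(global_rank), json.dumps(local_group_ranks))
    dist.barrier()  # ensure every rank has published before anyone reads

    merged = dict(local_group_ranks)
    for rank in range(world_size):
        if rank == global_rank:
            continue
        remote = json.loads(sync_store.get(str(rank)))
        # JSON serialization turns int keys into strings; convert them back
        merged.update({int(pg_id): ranks for pg_id, ranks in remote.items()})

    # sorted() guarantees every rank iterates, and calls new_group(), in the same order
    return dict(sorted(merged.items()))

Each entry of the merged dict is then handed to new_group() (via get_new_pg() in the diff) exactly once per pg_id on every rank.
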
et_replay/tools/comm_replay.py (5 changes: 4 additions & 1 deletion)
@@ -31,7 +31,8 @@
 from et_replay.comm.param_profile import paramProfile, paramTimer
 
 try:
-    from trainer_iteration_wrapper import setTrainingIteration  # @manual
+    # pyre-ignore[21]:
+    from trainer_iteration_wrapper import setTrainingIteration
 except ImportError:
     pass
 
@@ -89,6 +90,8 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> None:
         json.dump(commsTracePerf, write_file, indent=2)
 
 
+# pyre-ignore[13]: lint complains that self.backendFuncs is never initialized;
+# it is initialized in initBackend
 class commsTraceReplayBench(paramCommsBench):
     """
     A class to replay and benchmark generated traces for collective communications.
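
The pyre-ignore[13] comment added above suppresses Pyre's uninitialized-attribute check (error code 13) for attributes that are assigned outside __init__, which is the case for self.backendFuncs. A toy, hypothetical sketch of the deferred-initialization pattern being annotated (class and attribute names are illustrative):

# pyre-ignore[13]: backend is intentionally set in init_backend(), not in __init__
class ReplayBenchSketch:
    backend: str

    def init_backend(self, name: str) -> None:
        self.backend = name  # first assignment happens here, after construction

bench = ReplayBenchSketch()
bench.init_backend("nccl")
print(bench.backend)  # "nccl"
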
