diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 75de50bc..1e367dca 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -174,7 +174,6 @@ def alloc_ones(
         ipTensor = ipTensor * scaleFactor
         return ipTensor
 
-    @abstractmethod
     def noop(
         self,
         collectiveArgs: collectiveArgsHolder = None,
diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index cfd75068..8de26e35 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -3,8 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import logging
 import os
+
 from itertools import cycle
 from time import sleep
 from typing import List, Optional
@@ -1008,7 +1010,8 @@ def get_new_pg(self, group_ranks, backend):
                 ranks=group_ranks, backend=backend
             )
         else:
-            return dist.new_group(ranks=group_ranks, backend=backend)
+            pg = dist.new_group(ranks=group_ranks, backend=backend)
+            return pg
 
     def tensor_list_to_numpy(self, tensorList):
         if isinstance(tensorList, list):
@@ -1070,9 +1073,29 @@ def initialize_backend(
     def initialize_groups(self, backend="gloo"):
         groups = {}
         world_size = self.get_world_size()
+        global_rank = self.get_global_rank()
+
+        # sync pgs across ranks to fix hang with multiple comm groups,
+        # because the new_group() function requires that all processes in the main group enter,
+        # even if they are not going to be members of the group.
+        # Assumption: pg_name is unique and consistent for all ranks
+        sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
+        sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
+        torch.distributed.barrier()
+        group_ranks_sync = self.commsParams.groupRanks.copy()
+        for i in range(self.get_world_size()):
+            if i == global_rank:
+                continue
+            json_data = sync_store.get(str(i))
+
+            # convert pg_id in json_data to int
+            pg_id2group_ranks = {}
+            for pg_id, group_ranks in json.loads(json_data).items():
+                pg_id2group_ranks[int(pg_id)] = group_ranks
+            group_ranks_sync.update(pg_id2group_ranks)
 
         # create additional groups
-        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+        for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
             if (
                 len(group_ranks) > world_size
             ):  # this means that --auto-shrink is enabled, only use default pg
@@ -1084,11 +1107,9 @@ def initialize_groups(self, backend="gloo"):
                 pg = self.get_default_group()
             else:
                 pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
-                global_rank = self.get_global_rank()
-                if global_rank in group_ranks:
-                    logger.info(
-                        f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}"
-                    )
+                logger.info(
+                    f"initialize_groups: create new group, pg_id = {pg_id}, group_ranks = {group_ranks}"
+                )
             groups[pg_id] = pg
 
         # if additional groups are created, overwrite the default groups list
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index 96b22c88..d35e28ac 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -31,7 +31,8 @@
 from et_replay.comm.param_profile import paramProfile, paramTimer
 
 try:
-    from trainer_iteration_wrapper import setTrainingIteration  # @manual
+    # pyre-ignore[21]:
+    from trainer_iteration_wrapper import setTrainingIteration
 except ImportError:
     pass
 
@@ -89,6 +90,8 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> Non
         json.dump(commsTracePerf, write_file, indent=2)
 
 
+# pyre-ignore[13]: lint complains that self.backendFuncs is never initialized;
+# it is initialized in initBackend.
 class commsTraceReplayBench(paramCommsBench):
     """
     A class to replay and benchmark generated traces for collective communications.
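
Note on the group-synchronization change above: each rank publishes its locally known pg_id -> group_ranks mapping to a PrefixStore, waits at a barrier, then merges every other rank's mapping and iterates the result in sorted pg_id order, so all ranks call new_group() for the same groups in the same order and no rank hangs waiting for a group it never heard about. A minimal standalone sketch of that pattern follows (the helper name sync_group_ranks is hypothetical; it assumes torch.distributed is already initialized and that a shared dist.TCPStore, such as the self.tcp_store used in the diff, is passed in):

import json

import torch.distributed as dist


def sync_group_ranks(store, local_groups):
    """Merge every rank's {pg_id: group_ranks} view through a shared store."""
    sync_store = dist.PrefixStore("pg_sync_r", store)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # publish this rank's view, then wait until every rank has published
    sync_store.set(str(rank), json.dumps(local_groups))
    dist.barrier()

    merged = dict(local_groups)
    for other in range(world_size):
        if other == rank:
            continue
        remote = json.loads(sync_store.get(str(other)))
        # JSON keys come back as strings; convert pg_id back to int
        merged.update({int(pg_id): ranks for pg_id, ranks in remote.items()})

    # iterate in sorted pg_id order so every rank enters new_group() identically
    return dict(sorted(merged.items()))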