fix wait comm op for both collective and p2p
GSSBMW committed Oct 11, 2024
1 parent 8a45a73 commit 094a2ab
Showing 4 changed files with 22 additions and 32 deletions.
et_replay/comm/backend/base_backend.py (4 changes: 2 additions & 2 deletions)
@@ -59,7 +59,7 @@ def __init__(self) -> None:
         self.global_rank = -1
         self.backendFuncs = {}
         self.collective = ""
-        self.collectiveId = 0
+        self.wait_obj_key = (0, 0, False)  # (pg_id, req_id, is_p2p)
         self.pt2pt = ""
         self.src_rank = -1
         self.dst_rank = -1
@@ -100,7 +100,7 @@ def __init__(self) -> None:
         self.dataSize = 0
         self.numElements = 0
         self.waitObj = []
-        self.waitObjIds = {}  # mapping of reqID to future of async collectives
+        self.waitObjIds = {}  # mapping of (pg_id, req_id, is_p2p) to future of async collectives
 
         self.ipTensor_split_pair = []
         self.opTensor_split_pair = []
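
The base_backend.py change above replaces the single integer collectiveId with a composite (pg_id, req_id, is_p2p) key. A minimal standalone sketch (illustration only, not code from the repository) of why the composite key matters: a collective and a point-to-point op can carry the same request id, and the same request id can also recur across process groups, so an int-only key would collide.

# Illustration: two ops share req_id=5 yet stay distinct once the process group
# and the p2p flag are part of the key.
pending = {}  # maps (pg_id, req_id, is_p2p) -> pending work handle (placeholders here)

pending[(0, 5, False)] = "allreduce-work"   # collective with req_id 5 on pg 0
pending[(0, 5, True)] = "send-work"         # p2p op that also has req_id 5 on pg 0

assert len(pending) == 2                    # an int-only key (just 5) would have collided
print(pending[(0, 5, True)])                # -> send-work
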
et_replay/comm/backend/pytorch_dist_backend.py (28 changes: 7 additions & 21 deletions)
@@ -566,28 +566,14 @@ def complete_accel_ops(self, collectiveArgs, devSync=True):
         if devSync:
             self.device_sync(collectiveArgs)
 
-    # retFlag not used
-    def complete_single_op(self, collectiveArgs, retFlag=False):
-        """only wait on the first op in the queue"""
-        if len(collectiveArgs.waitObj) > 0:
-            waitReq = collectiveArgs.waitObj.pop(0)
-            if waitReq is not None:
-                waitReq.wait()
-
-        # to ensure GPU collective is completed
-        self.device_sync(collectiveArgs)
-
     def wait(self, collectiveArgs, retFlag=False):
-        # for backwards compatibility, use old wait functionality.
-        if len(collectiveArgs.waitObjIds) == 0:
-            self.complete_single_op(collectiveArgs)
-            return
-
-        """wait on op with the matching reqID"""
-        if collectiveArgs.collectiveId in collectiveArgs.waitObjIds:
-            waitObj = collectiveArgs.waitObjIds[collectiveArgs.collectiveId]
-            if waitObj is not None:
-                waitObj.wait()
+        # wait on op with the matching (pg_id, req_id, is_p2p)
+        if collectiveArgs.wait_obj_key in collectiveArgs.waitObjIds:
+            work = collectiveArgs.waitObjIds.pop(collectiveArgs.wait_obj_key)
+            for i, w in enumerate(collectiveArgs.waitObj):
+                if w is work:
+                    collectiveArgs.waitObj.pop(i)
+            work.wait()
 
     def barrier(self, collectiveArgs, name="dummy", retFlag=False):
         my_dev = self.get_device()
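
The rewritten wait() above looks the handle up by its composite key, removes it from both bookkeeping structures so it is not waited on twice, and then blocks on it. A standalone sketch of that pattern, using a dummy work object in place of a real torch.distributed work handle:

# Sketch of the wait bookkeeping (DummyWork is a stand-in; the real code holds
# backend work handles). Assumes the same handle was stored both in the ordered
# waitObj list and in the waitObjIds dict under its (pg_id, req_id, is_p2p) key.
class DummyWork:
    def __init__(self, name):
        self.name = name
    def wait(self):
        print(f"waited on {self.name}")

wait_obj = []                  # ordered list of outstanding handles
wait_obj_ids = {}              # (pg_id, req_id, is_p2p) -> handle

w = DummyWork("allreduce#7")
wait_obj.append(w)
wait_obj_ids[(0, 7, False)] = w

key = (0, 7, False)
if key in wait_obj_ids:
    work = wait_obj_ids.pop(key)           # remove the handle from the dict
    for i, pending in enumerate(wait_obj):
        if pending is work:                # identity check, mirroring `w is work`
            wait_obj.pop(i)                # drop it from the list as well
            break
    work.wait()                            # block until the op completes
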
et_replay/comm/commsTraceParser.py (6 changes: 4 additions & 2 deletions)
@@ -122,10 +122,12 @@ def _parse_comms_op_node(  # noqa: C901
     comm_nodes = (
         node for node in in_trace.nodes.values() if node.name == "record_param_comms"
    )
+    is_seq_id = lambda x: isinstance(x, list) and len(x) == 2 and isinstance(x[0], int) and isinstance(x[1], bool)
     for node in comm_nodes:
         # according to macro RECORD_PARAM_COMMS and RECORD_PARAM_COMMS_DATA in torch/csrc/distributed/c10d/ParamCommsUtils.hpp
-        # ["wait", "barrier", "init"] record 1st element as seq, others record starting from input tensor
-        index_base = 0 if isinstance(node.inputs[0], int) else 1
+        # ["wait", "barrier", "init"] record 1st element as seq_id, whose 1st element is an integer for sequence number, 2nd element is a bool for isP2P
+        # others record starting from input tensor
+        index_base = 0 if is_seq_id(node.inputs[0]) else 1
         req_id = node.inputs[index_base]
         recorded_rank = node.inputs[index_base + 2]
 
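
The parser change above switches the seq-id test from "first input is an int" to "first input is a [seq_num, isP2P] pair" when deciding where the recorded arguments start. A simplified sketch of the predicate and of how index_base and req_id fall out of it; the example input lists are invented for illustration and are not the real trace layout:

# Simplified sketch of the seq-id detection used to pick index_base.
is_seq_id = (
    lambda x: isinstance(x, list)
    and len(x) == 2
    and isinstance(x[0], int)
    and isinstance(x[1], bool)
)

wait_node_inputs = [[42, True], "other", "recorded", "fields"]  # "wait"/"barrier"/"init" style
coll_node_inputs = ["input_tensor", [42, False], 0, 8]          # other ops start with a tensor

for inputs in (wait_node_inputs, coll_node_inputs):
    index_base = 0 if is_seq_id(inputs[0]) else 1
    req_id = inputs[index_base]
    print(index_base, req_id)   # -> 0 [42, True]  then  1 [42, False]
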
et_replay/tools/comm_replay.py (16 changes: 9 additions & 7 deletions)
@@ -114,7 +114,7 @@ def __init__(self):
         self.shrink = False
         self.max_msg_cnt = 0  # 0 means no limit
         self.num_msg = 0
-        self.is_blocking = True
+        self.is_blocking = False
         self.do_warm_up = False
         self.reuse_tensors = False
 
@@ -802,9 +802,11 @@ def runComms(
             description=f"# PARAM replay {self.replayIter}:" + curBlockStack,
         ):
             if collName in self.backendFuncs.collectiveFunc.keys():
-                # record collectiveID for wait ops
-                if curComm.req is not None:
-                    self.collectiveArgs.collectiveId = curComm.req
+                # record wait_obj_key for wait ops
+                if curComm.req is not None and curComm.pgId is not None:
+                    self.collectiveArgs.wait_obj_key = (curComm.pgId, curComm.req[0], curComm.req[1])
+                else:
+                    self.collectiveArgs.wait_obj_key = None
 
                 # handle point-to-point separately
                 if collName in supportedP2pOps:
@@ -832,10 +834,10 @@ def runComms(
         if self.is_blocking:
             self.backendFuncs.complete_accel_ops(self.collectiveArgs)
 
-        # if nonblocking, then store the pair {reqID, future} so that we can wait on it later
+        # if nonblocking, then store the pair {(pg_id, reqID, isP2P), future} so that we can wait on it later
         # check if req id is recorded in trace for backwards compatibility
-        if curComm.req is not None and not self.is_blocking and collName != "wait":
-            self.collectiveArgs.waitObjIds[curComm.req] = retObj
+        if not self.is_blocking and collName != "wait" and self.collectiveArgs.wait_obj_key is not None:
+            self.collectiveArgs.waitObjIds[self.collectiveArgs.wait_obj_key] = retObj
 
         # For non-blocking, latency and global_latency are the same
         global_latency = latency = collTimer.getTimeUS()
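
In comm_replay.py the replayer now builds the composite key from the trace entry (curComm.pgId plus the recorded [seq_num, isP2P] pair) before issuing an op, and in non-blocking mode files the returned future under that key so a later "wait" entry can retrieve it. A condensed, illustrative sketch of that flow; DummyWork and the issue_op/replay_wait helpers are stand-ins, not the replay tool's API or the real trace schema:

# Illustrative flow: issue ops non-blocking, file the returned handle under
# (pg_id, req_id, is_p2p), then a later "wait" entry retrieves it by the same key.
class DummyWork:
    def wait(self):
        print("op finished")

wait_obj_ids = {}

def issue_op(pg_id, req):                   # req is recorded as [seq_num, is_p2p]
    work = DummyWork()                      # a real replay gets a handle from the backend
    if pg_id is not None and req is not None:
        key = (pg_id, req[0], req[1])
        wait_obj_ids[key] = work            # remember the future for the later wait
    return work

def replay_wait(pg_id, req):
    key = (pg_id, req[0], req[1])
    work = wait_obj_ids.pop(key, None)
    if work is not None:
        work.wait()

issue_op(0, [3, False])                     # e.g. an async allreduce on pg 0
replay_wait(0, [3, False])                  # -> op finished
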
