Skip to content

Commit

Permalink
remove function extractCommsInfo() for deprecatd basic trace
Browse files Browse the repository at this point in the history
  • Loading branch information
GSSBMW committed Oct 11, 2024
1 parent d86af8c commit 8a45a73
Showing 1 changed file with 12 additions and 63 deletions.
75 changes: 12 additions & 63 deletions et_replay/tools/comm_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from et_replay.comm import comms_utils
from et_replay.comm import commsTraceParser
from et_replay.comm.backend.base_backend import supportedP2pOps
from et_replay.comm.comms_utils import (
bootstrap_info_holder,
Expand Down Expand Up @@ -1547,69 +1548,17 @@ def readTrace(self, remotePath: str, rank: int) -> None:
self.readRawTrace(remotePath=remotePath, rank=rank)

# Convert trace to comms trace.
try:
from et_replay.comm import commsTraceParser
except ImportError:
logger.info("FB internals not present, using base parser.")
self.comms_trace = extractCommsInfo(self.comms_trace)
else:
self.comms_trace = commsTraceParser.parseTrace(
self.comms_trace,
self.trace_type,
(
self.trace_file
if not os.path.isdir(self.trace_file)
else f"{self.trace_file}/rank-{rank}.json"
),
rank,
self.backendFuncs.get_world_size(),
)


def extractCommsInfo(in_trace: List[Dict]) -> List[commsArgs]:
"""
Convert Basic Trace to comms trace format.
"""
# print("in extract comms info")
# exit(1)
newCommsTrace = []
for cnt, curComm in enumerate(in_trace):
newComm = commsArgs()
newComm.comms = paramToCommName(curComm["comms"].lower())
logger.info(f"in extract comms info of {newComm.comms}: {curComm}")
newComm.id = cnt
if "req" in curComm:
newComm.req = curComm["req"]
if "startTime_ns" in curComm:
newComm.startTimeNs = curComm["startTime_ns"]
if "markers" in curComm:
newComm.markerStack = curComm["markers"]
if "world_size" in curComm:
newComm.worldSize = curComm["world_size"]
if "root" in curComm:
newComm.root = curComm["root"]
if "pg_id" in curComm:
newComm.pgId = curComm["pg_id"]
if "global_ranks" in curComm:
newComm.groupRanks = curComm["global_ranks"]

if newComm.comms not in ("wait", "barrier", "init"):
newComm.inMsgSize = curComm["in_msg_size"]
newComm.outMsgSize = curComm["out_msg_size"]
newComm.dtype = curComm["dtype"]

if newComm.comms in ("all_to_allv"):
newComm.inSplit = curComm["in_split"]
newComm.outSplit = curComm["out_split"]

if newComm.comms in supportedP2pOps:
newComm.src_rank = curComm["src_rank"]
newComm.dst_rank = curComm["dst_rank"]
newComm.batch_p2p = curComm["use_batch"]

newCommsTrace.append(newComm)

return newCommsTrace
self.comms_trace = commsTraceParser.parseTrace(
self.comms_trace,
self.trace_type,
(
self.trace_file
if not os.path.isdir(self.trace_file)
else f"{self.trace_file}/rank-{rank}.json"
),
rank,
self.backendFuncs.get_world_size(),
)


def main() -> None:
Expand Down

0 comments on commit 8a45a73

Please sign in to comment.