diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index c4fd3879..289d25cb 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -18,6 +18,7 @@ import torch from et_replay.comm import comms_utils +from et_replay.comm import commsTraceParser from et_replay.comm.backend.base_backend import supportedP2pOps from et_replay.comm.comms_utils import ( bootstrap_info_holder, @@ -1547,69 +1548,17 @@ def readTrace(self, remotePath: str, rank: int) -> None: self.readRawTrace(remotePath=remotePath, rank=rank) # Convert trace to comms trace. - try: - from et_replay.comm import commsTraceParser - except ImportError: - logger.info("FB internals not present, using base parser.") - self.comms_trace = extractCommsInfo(self.comms_trace) - else: - self.comms_trace = commsTraceParser.parseTrace( - self.comms_trace, - self.trace_type, - ( - self.trace_file - if not os.path.isdir(self.trace_file) - else f"{self.trace_file}/rank-{rank}.json" - ), - rank, - self.backendFuncs.get_world_size(), - ) - - -def extractCommsInfo(in_trace: List[Dict]) -> List[commsArgs]: - """ - Convert Basic Trace to comms trace format. - """ - # print("in extract comms info") - # exit(1) - newCommsTrace = [] - for cnt, curComm in enumerate(in_trace): - newComm = commsArgs() - newComm.comms = paramToCommName(curComm["comms"].lower()) - logger.info(f"in extract comms info of {newComm.comms}: {curComm}") - newComm.id = cnt - if "req" in curComm: - newComm.req = curComm["req"] - if "startTime_ns" in curComm: - newComm.startTimeNs = curComm["startTime_ns"] - if "markers" in curComm: - newComm.markerStack = curComm["markers"] - if "world_size" in curComm: - newComm.worldSize = curComm["world_size"] - if "root" in curComm: - newComm.root = curComm["root"] - if "pg_id" in curComm: - newComm.pgId = curComm["pg_id"] - if "global_ranks" in curComm: - newComm.groupRanks = curComm["global_ranks"] - - if newComm.comms not in ("wait", "barrier", "init"): - newComm.inMsgSize = curComm["in_msg_size"] - newComm.outMsgSize = curComm["out_msg_size"] - newComm.dtype = curComm["dtype"] - - if newComm.comms in ("all_to_allv"): - newComm.inSplit = curComm["in_split"] - newComm.outSplit = curComm["out_split"] - - if newComm.comms in supportedP2pOps: - newComm.src_rank = curComm["src_rank"] - newComm.dst_rank = curComm["dst_rank"] - newComm.batch_p2p = curComm["use_batch"] - - newCommsTrace.append(newComm) - - return newCommsTrace + self.comms_trace = commsTraceParser.parseTrace( + self.comms_trace, + self.trace_type, + ( + self.trace_file + if not os.path.isdir(self.trace_file) + else f"{self.trace_file}/rank-{rank}.json" + ), + rank, + self.backendFuncs.get_world_size(), + ) def main() -> None: