diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py index fb81dbbf..884de37f 100644 --- a/et_replay/comm/commsTraceParser.py +++ b/et_replay/comm/commsTraceParser.py @@ -2,6 +2,8 @@ from __future__ import annotations import json + +import logging from typing import List, Tuple from et_replay import ExecutionTrace @@ -9,200 +11,41 @@ from et_replay.comm.backend.base_backend import supportedP2pOps from et_replay.comm.comms_utils import commsArgs -tensorDtypeMap = { - "Tensor(int)": "int", - "Tensor(float)": "float", - "Tensor(bool)": "bool", - "Tensor(long)": "long", - "Tensor(long int)": "long", - "Tensor(double)": "double", - "Tensor(half)": "half", - "Tensor(byte)": "byte", - "Tensor(c10::Half)": "half", - "Tensor(c10::BFloat16)": "bfloat16", - "Tensor(unsigned char)": "char", - "Tensor(signed char)": "char", -} +logger = logging.getLogger(__name__) def parseTrace( - in_trace: List, trace_type: str, target_rank: int, total_ranks: int + in_trace: List, + trace_type: str, + trace_file_path: str, + target_rank: int, + total_ranks: int, ) -> List: """ Parse trace files to be compatible with PARAM replay-mode. - Currently supports: Basic Trace, Kineto Unitrace, and PyTorch ET trace. + Currently supports: Chakra host execution trace. Args: in_trace: Trace file to be parsed. trace_type: Trace type to be parsed with + trace_file_path: Path of input trace file being loaded. target_rank: The current rank of the device. total_ranks: Total number of ranks. Returns: parsed_trace: Parsed trace that is compatible with PARAM replay-mode. """ - if trace_type == "basic": # Basic Trace - parsed_trace = _parseBasicTrace(in_trace) - elif trace_type == "et": # Execution Trace (e.g. PyTorch ET, Chakra) + if trace_type == "et": # Execution Trace (e.g. Chakra host execution trace) parsed_trace = _parseExecutionTrace( ExecutionTrace(in_trace), target_rank, total_ranks ) - elif trace_type == "kineto": # Kineto Unitrace - parsed_trace = _parseKinetoUnitrace(in_trace, target_rank) - else: - raise ValueError("Unrecognized trace format.") - - return parsed_trace - - -def _parseBasicTrace(in_trace: List): - """ - Convert Basic Trace to comms trace format. - """ - newCommsTrace = [] - for cnt, curComm in enumerate(in_trace): - newComm = commsArgs() - newComm.id = cnt - newComm.markerStack = curComm.get("markers") - if "comms" in curComm: - _parseBasicTraceComms(curComm, newComm) - - elif "compute" in curComm: - _parseBasicTraceCompute(curComm, newComm) - - if newComm.comms is not None or newComm.compute is not None: - newCommsTrace.append(newComm) - else: - raise ValueError( - "Trace file contains an element that is not a supported in PARAM! Please format all elements as comms or compute for replay." 
- ) - - return newCommsTrace - - -def _parseBasicTraceComms(curComm, newComm: commsArgs) -> None: - newComm.comms = comms_utils.paramToCommName(curComm["comms"].lower()) - if newComm.markerStack is None: - newComm.markerStack = [newComm.comms] - newComm.req = curComm.get("req") - newComm.startTimeNs = curComm.get("startTime_ns") - newComm.worldSize = curComm.get("world_size") - newComm.root = curComm.get("root") - newComm.pgId = curComm.get("pg_id") - newComm.groupRanks = curComm.get("global_ranks") - - if newComm.comms not in ("wait", "barrier", "init", "batch_isend_irecv"): - newComm.inMsgSize = curComm["in_msg_size"] - newComm.outMsgSize = curComm["out_msg_size"] - newComm.dtype = curComm["dtype"].lower() - - if newComm.comms == "all_to_allv": - newComm.inSplit = curComm["in_split"] - newComm.outSplit = curComm["out_split"] - - if newComm.comms in supportedP2pOps: - newComm.src_rank = curComm["src_rank"] - newComm.dst_rank = curComm["dst_rank"] - newComm.batch_p2p = curComm["use_batch"] - - -def _parseBasicTraceCompute(curComm, newComm: commsArgs) -> None: - newComm.compute = curComm["compute"].lower() - if newComm.markerStack is None: - newComm.markerStack = [newComm.compute] - # count = number of times to call the compute kernel - if "count" in curComm: - newComm.count = curComm["count"] - # if no count is specified, assume 1 - else: - newComm.count = 1 - if newComm.compute == "gemm": - if "mm_dim" in curComm: - newComm.mm0_dim0 = curComm.get("mm_dim") - newComm.mm0_dim1 = curComm.get("mm_dim") - newComm.mm1_dim0 = curComm.get("mm_dim") - newComm.mm1_dim1 = curComm.get("mm_dim") - else: - newComm.mm0_dim0 = curComm.get("mm0_dim0") - newComm.mm0_dim1 = curComm.get("mm0_dim1") - newComm.mm1_dim0 = curComm.get("mm1_dim0") - newComm.mm1_dim1 = curComm.get("mm1_dim1") - newComm.dtype = curComm.get("dtype").lower() - elif newComm.compute == "emb_lookup": - if "direction" in curComm: - newComm.direction = curComm["direction"] - else: - newComm.direction = "forward" - newComm.emb_dim = curComm.get("emb_dim") - newComm.num_embs = curComm.get("num_embs") - newComm.batch_size = curComm.get("batch_size") - newComm.num_emb_tables_per_device = curComm.get("num_emb_tables") - newComm.num_emb_tables_batched = -1 - newComm.bag_size = curComm.get("bag_size") else: raise ValueError( - f"Trace file contains {str(newComm.compute)} compute element that is not supported in PARAM!" + f"Specified trace type {trace_type} to {trace_file_path} is not supported. \ +Please check supported types with '--help'" ) - -def _parseKinetoUnitrace(in_trace: List, target_rank: int) -> List: - """ - Convert the Kineto unitrace w/ comms metadata to the clean common trace format for replay. 
- """ - newCommsTrace = [] - commsCnt = 0 - for entry in in_trace: - # TODO: figure the current marker stack if present - marker = "unknown" - pass - - if ( - "name" in entry - and entry["name"] == "record_param_comms" - and entry["args"]["rank"] == target_rank - ): - newComm = commsArgs() - newComm.comms = comms_utils.paramToCommName(entry["args"]["comms"].lower()) - newComm.id = commsCnt - newComm.inMsgSize = entry["args"]["in_msg_size"] - newComm.outMsgSize = entry["args"]["out_msg_size"] - newComm.dtype = entry["args"]["dtype"].lower() - newComm.inSplit = entry["args"]["in_split"] - newComm.outSplit = entry["args"]["out_split"] - newComm.markerStack = marker - - newCommsTrace.append(newComm) - commsCnt += 1 - - return newCommsTrace - - -def _getTensorInfoFromPyTorchETEntry( - tensor_container: List, container_type: str -) -> Tuple[int, int, str]: - """ - Extract message size, tensor count, type from PyTorch ET entry inputs/outputs field. - NOTE: This format can be changed at anytime. TODO: When an extract/parsing tool is available in ATC, switch to it. - """ - list_count = container_type.count("GenericList") - tensors = [] - if list_count == 2: - # GenericList[GenericList[Tensor(), Tensor()]] - tensors = tensor_container[0][0] - dtype = container_type.replace("GenericList[", "").split(",", 1)[0] - elif list_count == 1: - # GenericList[Tensor()] - tensors = tensor_container[0] - dtype = container_type.replace("GenericList[", "").replace("]", "") - else: - tensors.append(tensor_container[0]) - dtype = container_type - - msg_size = 0 - for tensor in tensors: - msg_size += tensor[3] - - return msg_size, dtype + return parsed_trace def _parseExecutionTrace( @@ -210,148 +53,147 @@ def _parseExecutionTrace( ) -> List: """ Convert the Execution Trace comms metadata to the common trace format for replay. 
- """ - # Execution Trace PG_ID types availability - ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3) - ET_BACKENDID = in_trace.schema_pytorch() < (1, 0, 3) - - initOps = [] - newCommsTrace = [] - backendIdToPgid = {} - pgRanksMap = {} - groupCnt = -1 - - # Parse PG info from ET - for node in in_trace.nodes.values(): - if "process_group:init" in node.name: - pgJson = node.inputs[0] - try: - pgObj = json.loads(pgJson) - except json.decoder.JSONDecodeError: # skip if pg_config_info is truncated - break + if in_trace.schema_pytorch() < (1, 0, 3): + raise ValueError( + f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}" + ) - for pg in pgObj: - if not pg["pg_name"].isdecimal(): - # TODO support local synchronization pg - continue - pgId = int(pg["pg_name"]) - ranks = pg["ranks"] - groupCnt = pg["group_count"] - pgRanksMap[pgId] = ( - ranks - if len(ranks) > 0 - else list(range(pg["group_size"])) - # rank list is empty when all ranks are in a pg + pg_ranks_map = _parse_proc_group_info( + in_trace + ) # key is pg id, value is global ranks in this pg + comms_op_list = _parse_comms_op_node( + in_trace, pg_ranks_map, target_rank, total_ranks + ) + + return comms_op_list + + +def _parse_proc_group_info(in_trace: ExecutionTrace): + pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } } + pg_init_nodes = ( + node for node in in_trace.nodes.values() if "process_group:init" in node.name + ) + for node in pg_init_nodes: + # info of this node is dumped using torch.distributed.distributed_c10d._world.pg_config_info + # at the start of profiling, but not not callback to torch.distributed.init_process_group() + # Pre-Assumption: all process groups has been created before profiling start. + try: + pg_objs = json.loads(node.inputs[0]) + except json.decoder.JSONDecodeError: # skip if pg_config_info is truncated + break + + pg_ranks_map[node.id] = {} + for pg in pg_objs: + if not pg["pg_name"].isdecimal(): + # TODO support local synchronization pg + logger.warning( + f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip." ) - if ET_BACKENDID: - backendId = pg["uid"] if "uid" in pg else pg["backend_id"] - backendIdToPgid[backendId] = pgId - break # only one process_group init node per trace - - # Parse comms nodes - for node in in_trace.nodes.values(): - if node.name == "record_param_comms": - shift = ( - 0 if len(node.inputs) == 8 or len(node.inputs) == 10 else 1 - ) # wait/barrier ops do not have an input tensor (len=7), shift index one over - newComm = commsArgs() - newComm.id = node.id - newComm.comms = comms_utils.paramToCommName( - node.inputs[4 - shift].lower() - ) # 5th value of inputs is colName - if newComm.comms == "init": continue - newComm.req = node.inputs[ - 1 - shift - ] # 2nd value of inputs is the req id of the collective - - pgIdentifier = node.inputs[ - 2 - shift - ] # 3rd value of inputs is the pg identifier of the collective - # Assign pg_id info for PGs that were created. 
- if ET_BACKENDID and pgIdentifier in backendIdToPgid: - newComm.pgId = backendIdToPgid[pgIdentifier] - newComm.groupRanks = pgRanksMap[newComm.pgId] - newComm.worldSize = len(newComm.groupRanks) - elif ET_PG_NAME_TUPLE and pgIdentifier[0].isdecimal(): - newComm.pgId = int(pgIdentifier[0]) - newComm.groupRanks = pgRanksMap[newComm.pgId] - newComm.worldSize = len(newComm.groupRanks) - - if newComm.comms not in ("wait", "barrier"): - ( - newComm.inMsgSize, - inMsgType, - ) = _getTensorInfoFromPyTorchETEntry(node.inputs, node.input_types[0]) - ( - newComm.outMsgSize, - _, - ) = _getTensorInfoFromPyTorchETEntry(node.outputs, node.output_types[0]) - newComm.dtype = tensorDtypeMap[ - inMsgType - ] # 1st value of input_types is the data type for the tensors - - if newComm.comms in supportedP2pOps: - if "send" in newComm.comms: - newComm.src_rank = target_rank - local_dst_rank = node.inputs[3 - shift] - newComm.dst_rank = newComm.groupRanks[local_dst_rank] - if "recv" in newComm.comms: - local_src_rank = node.inputs[3 - shift] - newComm.src_rank = newComm.groupRanks[local_src_rank] - newComm.dst_rank = target_rank - - if newComm.comms == "broadcast": - newComm.root = newComm.groupRanks[0] - newComm.srcOrDst = newComm.groupRanks[0] - - if newComm.comms == "all_to_allv": - # 6th value of inputs is in_split, split evenly if not provided - if not newComm.worldSize: - # if no pg info provided, use total ranks as world size - newComm.worldSize = total_ranks - newComm.inSplit = ( - node.inputs[5] - if node.inputs[5] - else [int(newComm.inMsgSize / newComm.worldSize)] - * newComm.worldSize + (pg_id, ranks, group_size, group_count) = [ + pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"] + ] + pg_id = int(pg_id) + pg_ranks_map[node.id][pg_id] = ( + ranks + if len(ranks) > 0 + else list(range(group_size)) + # rank list is empty when all ranks are in a pg + ) + break # only one process_group init node per trace + return pg_ranks_map + + +def _parse_comms_op_node( # noqa: C901 + in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int +): + comms_op_list = [] + + for node_id in pg_ranks_map: + for pg_id, ranks in pg_ranks_map[node_id].items(): + comm_args = _create_pg_init_node(node_id, pg_id, ranks, len(ranks)) + comms_op_list.append(comm_args) + + pg_ranks_map_flatten = {} + for _, v in pg_ranks_map.items(): + pg_ranks_map_flatten.update(v) + + comm_nodes = ( + node for node in in_trace.nodes.values() if node.name == "record_param_comms" + ) + for node in comm_nodes: + # according to macro RECORD_PARAM_COMMS and RECORD_PARAM_COMMS_DATA in torch/csrc/distributed/c10d/ParamCommsUtils.hpp + # ["wait", "barrier", "init"] record 1st element as seq, others record starting from input tensor + index_base = 0 if isinstance(node.inputs[0], int) else 1 + req_id = node.inputs[index_base] + recorded_rank = node.inputs[index_base + 2] + + comm_args = commsArgs() + comm_args.id = node.id + comm_args.comms = comms_utils.paramToCommName( + node.commArgs.collective_name.lower() + ) + if comm_args.comms == "init": + # init node has been built + continue + comm_args.req = req_id + + if node.commArgs.pg_name and node.commArgs.pg_name.isdecimal(): + comm_args.pgId = int(node.commArgs.pg_name) + comm_args.groupRanks = pg_ranks_map_flatten[comm_args.pgId] + comm_args.worldSize = len(comm_args.groupRanks) + + if comm_args.comms not in ("wait", "barrier"): + comm_args.inMsgSize = node.commArgs.in_msg_nelems + comm_args.outMsgSize = node.commArgs.out_msg_nelems + comm_args.dtype = 
node.commArgs.dtype.lower() + + # the recorded rank id in execution trace is local rank id in the process group + # we need to convert it to global rank for replay, check the function broadcast() of pytorch below: + # https://github.com/pytorch/pytorch/blob/6c4efd4e959017fc758fcc5dc32d8cc6a4b9164d/torch/distributed/distributed_c10d.py#L2404 + if comm_args.comms in supportedP2pOps: + if "send" in comm_args.comms: + (comm_args.src_rank, comm_args.dst_rank) = ( + target_rank, + comm_args.groupRanks[recorded_rank], ) - # 7th value of inputs is out_split, split evenly if not provided - newComm.outSplit = ( - node.inputs[6] - if node.inputs[6] - else [int(newComm.outMsgSize / newComm.worldSize)] - * newComm.worldSize + elif "recv" in comm_args.comms: + (comm_args.src_rank, comm_args.dst_rank) = ( + comm_args.groupRanks[recorded_rank], + target_rank, ) - newCommsTrace.append(newComm) - - # Build init node - initOps = [] - if groupCnt < 0: - # old format: To be removed - for pgId, ranks in pgRanksMap.items(): - newComm = create_pg_init_node(pgId, ranks, len(ranks)) - initOps.append(newComm) - else: - for pgId in range(groupCnt): - if pgId in pgRanksMap: - ranks = pgRanksMap[pgId] - else: - # create a dummy pg that the current rank is not part of - ranks = [0] if target_rank != 0 else [1] - - newComm = create_pg_init_node(pgId, ranks, len(ranks)) - initOps.append(newComm) + elif comm_args.comms in ["reduce", "broadcast", "gather", "scatter"]: + comm_args.root = comm_args.groupRanks[recorded_rank] + comm_args.groupRanks = comm_args.groupRanks + + if comm_args.comms == "all_to_allv": + if not comm_args.worldSize: + # if no pg info provided, use total ranks as world size + comm_args.worldSize = total_ranks + comm_args.inSplit = ( + json.loads(node.commArgs.in_split_size) + if json.loads(node.commArgs.in_split_size) + else [int(comm_args.inMsgSize / comm_args.worldSize)] + * comm_args.worldSize + ) + comm_args.outSplit = ( + json.loads(node.commArgs.out_split_size) + if json.loads(node.commArgs.out_split_size) + else [int(comm_args.outMsgSize / comm_args.worldSize)] + * comm_args.worldSize + ) + comms_op_list.append(comm_args) - return initOps + newCommsTrace + return comms_op_list -def create_pg_init_node(pg_id: int, ranks: List[int], world_size: int): - newComm = commsArgs() - newComm.comms = "init" - newComm.pgId = pg_id - newComm.req = -1 - newComm.groupRanks = ranks - newComm.worldSize = world_size - return newComm +def _create_pg_init_node(node_id: int, pg_id: int, ranks: List[int], world_size: int): + comm_args = commsArgs() + comm_args.id = node_id + comm_args.comms = "init" + comm_args.pgId = pg_id + comm_args.req = -1 + comm_args.groupRanks = ranks + comm_args.worldSize = world_size + return comm_args diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py index 08a0201d..b08e48c7 100644 --- a/et_replay/comm/comms_utils.py +++ b/et_replay/comm/comms_utils.py @@ -791,6 +791,8 @@ def __init__(self, supportedNwstacks: List[str] = None) -> None: "int8": torch.int8, "short": torch.short, "char": torch.int8, + "signed char": torch.int8, + "unsigned char": torch.uint8, } self.supportedDtype = list(self.dtypeMap.keys()) self.backendFuncs: BaseBackend diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index cd4ad941..261ca8cc 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -112,7 +112,15 @@ class _CommArgs: collective_name: str dtype: str - # .. 
TODO add more see https://github.com/pytorch/pytorch/issues/124674 + in_msg_nelems: int + out_msg_nelems: int + in_split_size: str + out_split_size: str + global_rank_start: int + global_rank_stride: int + pg_name: str + pg_desc: str + pg_size: int """ @@ -402,13 +410,13 @@ def __init__(self, json): # remove all dataloader ops self.remove_dataloader_ops() - def _versiontuple(self, v: str) -> Tuple[int]: + def _versiontuple(self, v: str) -> Tuple[int, int, int]: return tuple(map(int, (v.split(".")))) - def schema_pytorch(self) -> Tuple[int]: + def schema_pytorch(self) -> Tuple[int, int, int]: return self._versiontuple(self.schema.split("-")[0]) - def schema_chakra(self) -> Tuple[int]: + def schema_chakra(self) -> Tuple[int, int, int]: if "-" not in self.schema: return (0, 0, 0) return self._versiontuple(self.schema.split("-")[1]) @@ -435,6 +443,32 @@ def _read_attrs(cls, node: Dict[str, Any]) -> Tuple: return tuple(attr_dict.get(key, None) for key in cls.ATTR_TYPES.keys()) + # MUST keep the order the same as members of _CommArgs + COMM_ATTR_TYPES = { + "collective_name": str, + "dtype": str, + "in_msg_nelems": int, + "out_msg_nelems": int, + "in_split_size": str, + "out_split_size": str, + "global_rank_start": int, + "global_rank_stride": int, + "pg_name": str, + "pg_desc": str, + "pg_size": int, + } + + @classmethod + def _read_comm_attrs(cls, node: Dict[str, Any]) -> _CommArgs: + attr_dict = { + attr["name"]: cls.COMM_ATTR_TYPES[attr["name"]](attr["value"]) + for attr in node["attrs"] + if attr["name"] in cls.COMM_ATTR_TYPES.keys() + } + + params_dict = {k: attr_dict.get(k, None) for k in cls.COMM_ATTR_TYPES.keys()} + return _CommArgs(**params_dict) + @staticmethod def _create_node_v1_0_1(pid, x: Dict[str, Any]) -> Node: return Node( @@ -472,6 +506,12 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node: kernel_file, ) = ExecutionTrace._read_attrs(x) + comm_attrs = ( + ExecutionTrace._read_comm_attrs(x) + if x["name"] == "record_param_comms" + else None + ) + return Node( x["name"], x["id"], @@ -492,6 +532,7 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node: rf_id, kernel_backend, kernel_file, + comm_attrs, ) @staticmethod @@ -508,6 +549,12 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node: kernel_file, ) = ExecutionTrace._read_attrs(x) + comm_attrs = ( + ExecutionTrace._read_comm_attrs(x) + if x["name"] == "record_param_comms" + else None + ) + return Node( x["name"], x["id"], @@ -528,7 +575,7 @@ def _create_node_v1_1_1_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node: rf_id, kernel_backend, kernel_file, - None, + comm_attrs, x["inputs"]["strides"], x["outputs"]["strides"], ) diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 02352857..ff4f59a5 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -41,7 +41,8 @@ # sleep for 20ms to wait for next collective LOOP_TIMER_S = 0.02 -VALID_TRACE_TYPES = ["basic", "et", "kineto"] +# index 0 is default value of trace type +VALID_TRACE_TYPES = ["et"] def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> None: @@ -176,8 +177,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> None: parser.add_argument( "--trace-type", type=str, - default="basic", - help=f"Trace type used for replay. Supported trace types: {str(VALID_TRACE_TYPES)}. By default use basic trace.", + choices=VALID_TRACE_TYPES, + default=VALID_TRACE_TYPES[0], + help=f"Select trace type used for replay. 
Supported trace types: {VALID_TRACE_TYPES}. \ + 'et' represents Chakra host execution trace.", ) parser.add_argument( "--use-one-trace", @@ -594,8 +597,9 @@ def hashEtCommsOp(self, commsOp: commsArgs) -> int: commsOp.pgId, commsOp.inMsgSize, commsOp.outMsgSize, - commsOp.inSplit, - commsOp.outSplit, + # inSplit and outSplit are list type, need to be converted for hash + tuple(commsOp.inSplit), + tuple(commsOp.outSplit), ) else: op = ( @@ -801,8 +805,8 @@ def runComms( self.collectiveArgs.collective = collName self.backendFuncs.P2POp(self.collectiveArgs, retFlag=True) - if collName in ["broadcast"]: - self.collectiveArgs.srcOrDst = curComm.srcOrDst + if collName in ["reduce", "broadcast", "gather", "scatter"]: + self.collectiveArgs.srcOrDst = curComm.root retObj = self.backendFuncs.collectiveFunc[collName]( self.collectiveArgs, retFlag=True @@ -1592,6 +1596,11 @@ def readTrace(self, remotePath: str, rank: int) -> None: self.comms_trace = commsTraceParser.parseTrace( self.comms_trace, self.trace_type, + ( + self.trace_file + if not os.path.isdir(self.trace_file) + else f"{self.trace_file}/{rank}.json" + ), rank, self.backendFuncs.get_world_size(), )
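
The reworked `parseTrace` entry point in commsTraceParser.py now takes the trace file path as an explicit argument and accepts only the Chakra host execution trace. Below is a minimal sketch of how a caller such as `comm_replay.readTrace` is expected to invoke it; the trace path, rank, and world size are made up for illustration, and the import path is assumed from the module layout this patch touches.

```python
# Minimal sketch of calling the reworked parser entry point. The trace path,
# rank, and world size are hypothetical; trace_type="et" (Chakra host execution
# trace) is the only value still accepted after this change.
import json

from et_replay.comm import commsTraceParser  # assumed import path for this module

trace_file = "/tmp/host_et_rank0.json"  # hypothetical per-rank host ET dump
with open(trace_file) as f:
    raw_trace = json.load(f)

comms_ops = commsTraceParser.parseTrace(
    raw_trace,
    "et",        # trace_type: only Chakra host execution traces are supported
    trace_file,  # trace_file_path: now threaded through for error reporting
    0,           # target_rank
    8,           # total_ranks
)
print(f"parsed {len(comms_ops)} comm ops, including synthetic 'init' entries")
```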
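
The new schema gate in `_parseExecutionTrace` rejects traces whose PyTorch schema component is older than 1.0.3, since older traces predate the `commArgs` attributes the parser now reads. A standalone sketch of that check, using a made-up schema string:

```python
# Standalone sketch of the schema gate added to _parseExecutionTrace. The schema
# string is made up; only the PyTorch component (before "-") is compared, matching
# ExecutionTrace.schema_pytorch().
def versiontuple(v: str) -> tuple:
    return tuple(map(int, v.split(".")))

schema = "1.0.3-1.0.2"  # hypothetical "<pytorch>-<chakra>" schema header
if versiontuple(schema.split("-")[0]) < (1, 0, 3):
    raise ValueError(
        f"Only trace versions >= 1.0.3 are supported, got {schema.split('-')[0]}"
    )
print("schema accepted:", schema)
```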
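
`_parse_comms_op_node` translates the rank recorded by `record_param_comms`, which is local to the process group, into a global rank before replay (see the `broadcast()` reference linked in the hunk). A toy example of that translation, with made-up ranks:

```python
# Toy illustration (not library code) of the rank translation in
# _parse_comms_op_node: record_param_comms stores the peer/root as an index into
# the process group's rank list, while replay needs the global rank.
group_ranks = [4, 5, 6, 7]  # global ranks belonging to one process group
recorded_rank = 2           # local index recorded in the trace

# Rooted collectives (reduce/broadcast/gather/scatter): root becomes a global rank.
root = group_ranks[recorded_rank]
assert root == 6

# Point-to-point ops: the replaying rank fills in its own side of the pair.
target_rank = 5                                               # hypothetical replaying rank
send_src, send_dst = target_rank, group_ranks[recorded_rank]  # "send"
recv_src, recv_dst = group_ranks[recorded_rank], target_rank  # "recv"
print(send_src, send_dst, recv_src, recv_dst)
```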
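
For `all_to_allv`, the split sizes arrive as JSON-encoded strings in `commArgs`, and an empty list means the message is split evenly across the process group. The helper below mirrors that fallback logic; its name is ours, not the library's, and floor division stands in for the `int(msgSize / worldSize)` expression in the hunk.

```python
import json

# Standalone mirror of the all_to_allv split fallback in _parse_comms_op_node.
def resolve_split(split_size_json: str, msg_nelems: int, world_size: int) -> list:
    split = json.loads(split_size_json)
    return split if split else [msg_nelems // world_size] * world_size

print(resolve_split("[]", 1024, 8))          # even split: [128] * 8
print(resolve_split("[512, 512]", 1024, 2))  # explicit splits are kept as-is
```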
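
`ExecutionTrace._read_comm_attrs` pulls the new comm attributes off a `record_param_comms` node, casting them via `COMM_ATTR_TYPES` and defaulting missing keys to `None` before building `_CommArgs`. The sketch below mirrors that pattern with a stand-in dataclass and a fabricated node dict so it runs without the library.

```python
from dataclasses import dataclass
from typing import Optional

# Standalone mirror of ExecutionTrace._read_comm_attrs: pick the known comm attrs
# off a record_param_comms node, cast them with COMM_ATTR_TYPES, and default the
# missing ones to None. The dataclass and node dict below are fabricated stand-ins.
COMM_ATTR_TYPES = {
    "collective_name": str,
    "dtype": str,
    "in_msg_nelems": int,
    "out_msg_nelems": int,
    "in_split_size": str,
    "out_split_size": str,
    "global_rank_start": int,
    "global_rank_stride": int,
    "pg_name": str,
    "pg_desc": str,
    "pg_size": int,
}

@dataclass
class CommArgsSketch:  # stand-in for et_replay.execution_trace._CommArgs
    collective_name: Optional[str] = None
    dtype: Optional[str] = None
    in_msg_nelems: Optional[int] = None
    out_msg_nelems: Optional[int] = None
    in_split_size: Optional[str] = None
    out_split_size: Optional[str] = None
    global_rank_start: Optional[int] = None
    global_rank_stride: Optional[int] = None
    pg_name: Optional[str] = None
    pg_desc: Optional[str] = None
    pg_size: Optional[int] = None

node = {  # fabricated record_param_comms node
    "name": "record_param_comms",
    "attrs": [
        {"name": "collective_name", "value": "all_reduce"},
        {"name": "dtype", "value": "Float"},
        {"name": "in_msg_nelems", "value": 1024},
        {"name": "out_msg_nelems", "value": 1024},
        {"name": "pg_name", "value": "0"},
        {"name": "pg_size", "value": 8},
    ],
}

attr_dict = {
    a["name"]: COMM_ATTR_TYPES[a["name"]](a["value"])
    for a in node["attrs"]
    if a["name"] in COMM_ATTR_TYPES
}
comm_args = CommArgsSketch(**{k: attr_dict.get(k) for k in COMM_ATTR_TYPES})
print(comm_args.collective_name, comm_args.in_msg_nelems, comm_args.pg_name)
```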
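
The `hashEtCommsOp` change wraps `inSplit`/`outSplit` in `tuple()` because Python lists are unhashable, so hashing an op tuple that still contains raw lists raises `TypeError`. A short demonstration:

```python
# Why hashEtCommsOp now wraps inSplit/outSplit in tuple(): lists are unhashable,
# so an all_to_allv op whose splits stayed as lists would make hash() raise.
op_with_lists = ("all_to_allv", 0, 1024, 1024, [128] * 8, [128] * 8)
try:
    hash(op_with_lists)
except TypeError as err:
    print("unhashable:", err)

op_with_tuples = ("all_to_allv", 0, 1024, 1024, tuple([128] * 8), tuple([128] * 8))
print(hash(op_with_tuples))  # hashes fine once the splits are tuples
```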