From 7b0202397c259feb8ddefc1ce149106b6e6fd679 Mon Sep 17 00:00:00 2001
From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com>
Date: Wed, 31 Jul 2024 12:06:17 -0400
Subject: [PATCH] Run black

---
 et_replay/comm/commsTraceParser.py | 76 ++++++++++++++++++++++--------
 et_replay/comm/comms_utils.py      |  2 +-
 et_replay/execution_trace.py       | 10 ++--
 et_replay/tools/comm_replay.py     |  6 ++-
 4 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
index 78914395..b0553edd 100644
--- a/et_replay/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -10,6 +10,7 @@
 from et_replay.comm.comms_utils import commsArgs
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 tensorDtypeMap = {
@@ -29,7 +30,11 @@
 
 
 def parseTrace(
-    in_trace: List, trace_type: str, trace_file_path: str, target_rank: int, total_ranks: int
+    in_trace: List,
+    trace_type: str,
+    trace_file_path: str,
+    target_rank: int,
+    total_ranks: int,
 ) -> List:
     """
     Parse trace files to be compatible with PARAM replay-mode.
@@ -50,8 +55,10 @@ def parseTrace(
             ExecutionTrace(in_trace), target_rank, total_ranks
         )
     else:
-        raise ValueError(f"Specified trace type {trace_type} to {trace_file_path} is not supported. \
-Please check supported types with '--help'")
+        raise ValueError(
+            f"Specified trace type {trace_type} to {trace_file_path} is not supported. \
+Please check supported types with '--help'"
+        )
 
     return parsed_trace
 
@@ -91,17 +98,26 @@ def _parseExecutionTrace(
     Convert the Execution Trace comms metadata to the common trace format for replay.
     """
     ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
-    if (in_trace.schema_pytorch() < (1, 0, 3)):
-        raise ValueError(f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}")
+    if in_trace.schema_pytorch() < (1, 0, 3):
+        raise ValueError(
+            f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}"
+        )
 
-    pg_ranks_map = _parse_proc_group_info(in_trace) # key is pg id, value is global ranks in this pg
-    comms_op_list = _parse_comms_op_node(in_trace, pg_ranks_map, target_rank, total_ranks)
+    pg_ranks_map = _parse_proc_group_info(
+        in_trace
+    )  # key is pg id, value is global ranks in this pg
+    comms_op_list = _parse_comms_op_node(
+        in_trace, pg_ranks_map, target_rank, total_ranks
+    )
 
     return comms_op_list
 
+
 def _parse_proc_group_info(in_trace: ExecutionTrace):
-    pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } }
-    pg_init_nodes = (node for node in in_trace.nodes.values() if "process_group:init" in node.name)
+    pg_ranks_map = {}  # {node_id : {process_group_id : [ranks] } }
+    pg_init_nodes = (
+        node for node in in_trace.nodes.values() if "process_group:init" in node.name
+    )
     for node in pg_init_nodes:
         # info of this node is dumped using torch.distributed.distributed_c10d._world.pg_config_info
         # at the start of profiling, but not as a callback to torch.distributed.init_process_group()
@@ -115,9 +131,13 @@ def _parse_proc_group_info(in_trace: ExecutionTrace):
         for pg in pg_objs:
             if not pg["pg_name"].isdecimal():
                 # TODO support local synchronization pg
-                logger.warning(f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip.")
+                logger.warning(
+                    f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip."
+                )
                 continue
-            (pg_id, ranks, group_size, group_count) = [pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"]]
+            (pg_id, ranks, group_size, group_count) = [
+                pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"]
+            ]
             pg_id = int(pg_id)
             pg_ranks_map[node.id][pg_id] = (
                 ranks
@@ -128,7 +148,10 @@
         break # only one process_group init node per trace
     return pg_ranks_map
 
-def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int):
+
+def _parse_comms_op_node(
+    in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int
+):
     comms_op_list = []
 
     for node_id in pg_ranks_map:
@@ -137,10 +160,12 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_ra
         comms_op_list.append(comm_args)
 
     pg_ranks_map_flatten = {}
-    for k,v in pg_ranks_map.items():
+    for k, v in pg_ranks_map.items():
         pg_ranks_map_flatten.update(v)
 
-    comm_nodes = (node for node in in_trace.nodes.values() if node.name == "record_param_comms")
+    comm_nodes = (
+        node for node in in_trace.nodes.values() if node.name == "record_param_comms"
+    )
     for node in comm_nodes:
         # according to macro RECORD_PARAM_COMMS and RECORD_PARAM_COMMS_DATA in torch/csrc/distributed/c10d/ParamCommsUtils.hpp
         # ["wait", "barrier", "init"] record 1st element as seq, others record starting from input tensor
@@ -150,13 +175,15 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_ra
 
         comm_args = commsArgs()
         comm_args.id = node.id
-        comm_args.comms = comms_utils.paramToCommName(node.commArgs.collective_name.lower())
+        comm_args.comms = comms_utils.paramToCommName(
+            node.commArgs.collective_name.lower()
+        )
         if comm_args.comms == "init":
             # init node has been built
             continue
         comm_args.req = req_id
 
-        if (node.commArgs.pg_name and node.commArgs.pg_name.isdecimal()):
+        if node.commArgs.pg_name and node.commArgs.pg_name.isdecimal():
             comm_args.pgId = int(node.commArgs.pg_name)
             comm_args.groupRanks = pg_ranks_map_flatten[comm_args.pgId]
             comm_args.worldSize = len(comm_args.groupRanks)
@@ -171,9 +198,15 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_ra
         # https://github.com/pytorch/pytorch/blob/6c4efd4e959017fc758fcc5dc32d8cc6a4b9164d/torch/distributed/distributed_c10d.py#L2404
         if comm_args.comms in supportedP2pOps:
             if "send" in comm_args.comms:
-                (comm_args.src_rank, comm_args.dst_rank) = (target_rank, comm_args.groupRanks[recorded_rank])
+                (comm_args.src_rank, comm_args.dst_rank) = (
+                    target_rank,
+                    comm_args.groupRanks[recorded_rank],
+                )
             elif "recv" in comm_args.comms:
-                (comm_args.src_rank, comm_args.dst_rank) = (comm_args.groupRanks[recorded_rank], target_rank)
+                (comm_args.src_rank, comm_args.dst_rank) = (
+                    comm_args.groupRanks[recorded_rank],
+                    target_rank,
+                )
         elif comm_args.comms in ["reduce", "broadcast", "gather", "scatter"]:
             comm_args.root = comm_args.groupRanks[recorded_rank]
             comm_args.groupRanks = comm_args.groupRanks
@@ -185,17 +218,20 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_ra
             comm_args.inSplit = (
                 json.loads(node.commArgs.in_split_size)
                 if json.loads(node.commArgs.in_split_size)
-                else [int(comm_args.inMsgSize / comm_args.worldSize)] * comm_args.worldSize
+                else [int(comm_args.inMsgSize / comm_args.worldSize)]
+                * comm_args.worldSize
             )
             comm_args.outSplit = (
                 json.loads(node.commArgs.out_split_size)
                 if json.loads(node.commArgs.out_split_size)
-                else [int(comm_args.outMsgSize / comm_args.worldSize)] * comm_args.worldSize
+                else [int(comm_args.outMsgSize / comm_args.worldSize)]
+                * comm_args.worldSize
             )
         comms_op_list.append(comm_args)
 
     return comms_op_list
 
+
 def _create_pg_init_node(node_id: int, pg_id: int, ranks: List[int], world_size: int):
     comm_args = commsArgs()
     comm_args.id = node_id
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index dea4c3c1..b08e48c7 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -792,7 +792,7 @@ def __init__(self, supportedNwstacks: List[str] = None) -> None:
             "short": torch.short,
             "char": torch.int8,
             "signed char": torch.int8,
-            "unsigned char": torch.uint8
+            "unsigned char": torch.uint8,
         }
         self.supportedDtype = list(self.dtypeMap.keys())
         self.backendFuncs: BaseBackend
diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py
index def8e3c1..b59c3f96 100644
--- a/et_replay/execution_trace.py
+++ b/et_replay/execution_trace.py
@@ -445,7 +445,7 @@ def _read_comm_attrs(cls, node: Dict[str, Any]) -> _CommArgs:
             if attr["name"] in cls.COMM_ATTR_TYPES.keys()
         }
 
-        params_dict = {k:attr_dict.get(k, None) for k in cls.COMM_ATTR_TYPES.keys()}
+        params_dict = {k: attr_dict.get(k, None) for k in cls.COMM_ATTR_TYPES.keys()}
         return _CommArgs(**params_dict)
 
     @staticmethod
@@ -485,7 +485,11 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
             kernel_file,
         ) = ExecutionTrace._read_attrs(x)
 
-        comm_attrs = ExecutionTrace._read_comm_attrs(x) if x['name'] == "record_param_comms" else None
+        comm_attrs = (
+            ExecutionTrace._read_comm_attrs(x)
+            if x["name"] == "record_param_comms"
+            else None
+        )
 
         return Node(
             x["name"],
@@ -507,7 +511,7 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
             rf_id,
             kernel_backend,
             kernel_file,
-            comm_attrs
+            comm_attrs,
         )
 
     @staticmethod
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index 35daff98..8741edbb 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -1592,7 +1592,11 @@ def readTrace(self, remotePath: str, rank: int) -> None:
         self.comms_trace = commsTraceParser.parseTrace(
             self.comms_trace,
             self.trace_type,
-            (self.trace_file if not os.path.isdir(self.trace_file) else f"{self.trace_file}/{rank}.json"),
+            (
+                self.trace_file
+                if not os.path.isdir(self.trace_file)
+                else f"{self.trace_file}/{rank}.json"
+            ),
             rank,
             self.backendFuncs.get_world_size(),
         )
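
For reviewers who want to confirm a checkout is clean under the same formatter, the following is a minimal sketch using black's Python API (black.format_str, black.Mode, and black.InvalidInput are real entry points; the "et_replay" path and the default Mode() settings are assumptions, since the commit does not pin a black version or configuration):

    # Minimal sketch: report files that black's default mode would still rewrite.
    # Assumes `pip install black`; results can differ across black versions.
    from pathlib import Path

    import black


    def needs_formatting(path: Path) -> bool:
        """Return True if black (default Mode) would change this file."""
        src = path.read_text()
        try:
            return black.format_str(src, mode=black.Mode()) != src
        except black.InvalidInput:
            # Not parseable as Python; skip rather than guess.
            return False


    if __name__ == "__main__":
        for py in Path("et_replay").rglob("*.py"):
            if needs_formatting(py):
                print(f"would reformat: {py}")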