Commit

Run black
TaekyungHeo committed Aug 5, 2024
1 parent 7dcd0eb commit 7b02023
Showing 4 changed files with 69 additions and 25 deletions.
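Every hunk below is mechanical reformatting produced by the black code formatter; no behavioral change is intended. The commit message does not record the exact command, but a typical invocation would be python -m black et_replay from the repository root, and python -m black --check et_replay verifies formatting without rewriting files. As a purely illustrative sketch (the function below is hypothetical and not taken from the diff), these are the rewrites black applies at its default 88-character line length: signatures and calls that no longer fit are exploded one argument per line with a trailing comma, redundant parentheses around conditions are dropped, and inline comments get at least two spaces before the "#".

# Hypothetical before/after, not part of this commit.
#
# Before black:
# def parse(in_trace: list, trace_type: str, trace_file_path: str, target_rank: int, total_ranks: int) -> list:
#     pg_ranks_map = {} # node_id -> ranks
#     if (trace_type != "et"):
#         raise ValueError(f"unsupported trace type {trace_type}")
#     return [pg_ranks_map]

# After black:
def parse(
    in_trace: list,
    trace_type: str,
    trace_file_path: str,
    target_rank: int,
    total_ranks: int,
) -> list:
    pg_ranks_map = {}  # node_id -> ranks
    if trace_type != "et":
        raise ValueError(f"unsupported trace type {trace_type}")
    return [pg_ranks_map]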
et_replay/comm/commsTraceParser.py: 76 changes (56 additions, 20 deletions)
@@ -10,6 +10,7 @@
from et_replay.comm.comms_utils import commsArgs

import logging

logger = logging.getLogger(__name__)

tensorDtypeMap = {
@@ -29,7 +30,11 @@


def parseTrace(
in_trace: List, trace_type: str, trace_file_path: str, target_rank: int, total_ranks: int
in_trace: List,
trace_type: str,
trace_file_path: str,
target_rank: int,
total_ranks: int,
) -> List:
"""
Parse trace files to be compatible with PARAM replay-mode.
@@ -50,8 +55,10 @@ def parseTrace(
ExecutionTrace(in_trace), target_rank, total_ranks
)
else:
raise ValueError(f"Specified trace type {trace_type} to {trace_file_path} is not supported. \
Please check supported types with '--help'")
raise ValueError(
f"Specified trace type {trace_type} to {trace_file_path} is not supported. \
Please check supported types with '--help'"
)

return parsed_trace

@@ -91,17 +98,26 @@ def _parseExecutionTrace(
Convert the Execution Trace comms metadata to the common trace format for replay.
"""
ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
if (in_trace.schema_pytorch() < (1, 0, 3)):
raise ValueError(f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}")
if in_trace.schema_pytorch() < (1, 0, 3):
raise ValueError(
f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}"
)

pg_ranks_map = _parse_proc_group_info(in_trace) # key is pg id, value is global ranks in this pg
comms_op_list = _parse_comms_op_node(in_trace, pg_ranks_map, target_rank, total_ranks)
pg_ranks_map = _parse_proc_group_info(
in_trace
) # key is pg id, value is global ranks in this pg
comms_op_list = _parse_comms_op_node(
in_trace, pg_ranks_map, target_rank, total_ranks
)

return comms_op_list


def _parse_proc_group_info(in_trace: ExecutionTrace):
pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } }
pg_init_nodes = (node for node in in_trace.nodes.values() if "process_group:init" in node.name)
pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } }
pg_init_nodes = (
node for node in in_trace.nodes.values() if "process_group:init" in node.name
)
for node in pg_init_nodes:
# info of this node is dumped using torch.distributed.distributed_c10d._world.pg_config_info
# at the start of profiling, but not via a callback to torch.distributed.init_process_group()
@@ -115,9 +131,13 @@ def _parse_proc_group_info(in_trace: ExecutionTrace):
for pg in pg_objs:
if not pg["pg_name"].isdecimal():
# TODO support local synchronization pg
logger.warning(f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip.")
logger.warning(
f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip."
)
continue
(pg_id, ranks, group_size, group_count) = [pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"]]
(pg_id, ranks, group_size, group_count) = [
pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"]
]
pg_id = int(pg_id)
pg_ranks_map[node.id][pg_id] = (
ranks
@@ -128,7 +148,10 @@ def _parse_proc_group_info(in_trace: ExecutionTrace):
break # only one process_group init node per trace
return pg_ranks_map

def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int):

def _parse_comms_op_node(
in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int
):
comms_op_list = []

for node_id in pg_ranks_map:
@@ -137,10 +160,12 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int):
comms_op_list.append(comm_args)

pg_ranks_map_flatten = {}
for k,v in pg_ranks_map.items():
for k, v in pg_ranks_map.items():
pg_ranks_map_flatten.update(v)

comm_nodes = (node for node in in_trace.nodes.values() if node.name == "record_param_comms")
comm_nodes = (
node for node in in_trace.nodes.values() if node.name == "record_param_comms"
)
for node in comm_nodes:
# according to macro RECORD_PARAM_COMMS and RECORD_PARAM_COMMS_DATA in torch/csrc/distributed/c10d/ParamCommsUtils.hpp
# ["wait", "barrier", "init"] record 1st element as seq, others record starting from input tensor
@@ -150,13 +175,15 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int):

comm_args = commsArgs()
comm_args.id = node.id
comm_args.comms = comms_utils.paramToCommName(node.commArgs.collective_name.lower())
comm_args.comms = comms_utils.paramToCommName(
node.commArgs.collective_name.lower()
)
if comm_args.comms == "init":
# init node has been built
continue
comm_args.req = req_id

if (node.commArgs.pg_name and node.commArgs.pg_name.isdecimal()):
if node.commArgs.pg_name and node.commArgs.pg_name.isdecimal():
comm_args.pgId = int(node.commArgs.pg_name)
comm_args.groupRanks = pg_ranks_map_flatten[comm_args.pgId]
comm_args.worldSize = len(comm_args.groupRanks)
@@ -171,9 +198,15 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int):
# https://github.com/pytorch/pytorch/blob/6c4efd4e959017fc758fcc5dc32d8cc6a4b9164d/torch/distributed/distributed_c10d.py#L2404
if comm_args.comms in supportedP2pOps:
if "send" in comm_args.comms:
(comm_args.src_rank, comm_args.dst_rank) = (target_rank, comm_args.groupRanks[recorded_rank])
(comm_args.src_rank, comm_args.dst_rank) = (
target_rank,
comm_args.groupRanks[recorded_rank],
)
elif "recv" in comm_args.comms:
(comm_args.src_rank, comm_args.dst_rank) = (comm_args.groupRanks[recorded_rank], target_rank)
(comm_args.src_rank, comm_args.dst_rank) = (
comm_args.groupRanks[recorded_rank],
target_rank,
)
elif comm_args.comms in ["reduce", "broadcast", "gather", "scatter"]:
comm_args.root = comm_args.groupRanks[recorded_rank]
comm_args.groupRanks = comm_args.groupRanks
@@ -185,17 +218,20 @@ def _parse_comms_op_node(in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int):
comm_args.inSplit = (
json.loads(node.commArgs.in_split_size)
if json.loads(node.commArgs.in_split_size)
else [int(comm_args.inMsgSize / comm_args.worldSize)] * comm_args.worldSize
else [int(comm_args.inMsgSize / comm_args.worldSize)]
* comm_args.worldSize
)
comm_args.outSplit = (
json.loads(node.commArgs.out_split_size)
if json.loads(node.commArgs.out_split_size)
else [int(comm_args.outMsgSize / comm_args.worldSize)] * comm_args.worldSize
else [int(comm_args.outMsgSize / comm_args.worldSize)]
* comm_args.worldSize
)
comms_op_list.append(comm_args)

return comms_op_list


def _create_pg_init_node(node_id: int, pg_id: int, ranks: List[int], world_size: int):
comm_args = commsArgs()
comm_args.id = node_id
et_replay/comm/comms_utils.py: 2 changes (1 addition, 1 deletion)
@@ -792,7 +792,7 @@ def __init__(self, supportedNwstacks: List[str] = None) -> None:
"short": torch.short,
"char": torch.int8,
"signed char": torch.int8,
"unsigned char": torch.uint8
"unsigned char": torch.uint8,
}
self.supportedDtype = list(self.dtypeMap.keys())
self.backendFuncs: BaseBackend
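The single change in comms_utils.py above is black's trailing-comma handling: when a collection has to stay split across lines, black appends a comma after the last element, and a trailing comma that is already present (the "magic trailing comma") tells black to keep the collection exploded one item per line. A minimal hypothetical sketch, with names that are not from this repository:

# The literal below is too long for one 88-character line, so black keeps it
# exploded and appends a comma after the last entry.
dtype_map = {
    "float": "float32",
    "double": "float64",
    "unsigned char": "uint8",
    "signed char": "int8",  # trailing comma added by black
}
# Once present, the trailing comma also pins the one-item-per-line layout,
# so black will not collapse the dict back onto a single line later.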
et_replay/execution_trace.py: 10 changes (7 additions, 3 deletions)
@@ -445,7 +445,7 @@ def _read_comm_attrs(cls, node: Dict[str, Any]) -> _CommArgs:
if attr["name"] in cls.COMM_ATTR_TYPES.keys()
}

params_dict = {k:attr_dict.get(k, None) for k in cls.COMM_ATTR_TYPES.keys()}
params_dict = {k: attr_dict.get(k, None) for k in cls.COMM_ATTR_TYPES.keys()}
return _CommArgs(**params_dict)

@staticmethod
@@ -485,7 +485,11 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
kernel_file,
) = ExecutionTrace._read_attrs(x)

comm_attrs = ExecutionTrace._read_comm_attrs(x) if x['name'] == "record_param_comms" else None
comm_attrs = (
ExecutionTrace._read_comm_attrs(x)
if x["name"] == "record_param_comms"
else None
)

return Node(
x["name"],
@@ -507,7 +511,7 @@ def _create_node_v1_0_2_chakra_0_0_4(pid, x: Dict[str, Any]) -> Node:
rf_id,
kernel_backend,
kernel_file,
comm_attrs
comm_attrs,
)

@staticmethod
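The rewrites in execution_trace.py show two more black rules that also appear in comm_replay.py below: a conditional expression that exceeds the line length is wrapped in parentheses and split before its if and else keywords, and single-quoted strings such as x['name'] are normalized to double quotes, x["name"]. A hypothetical sketch with illustrative names, not taken from the repository:

# Hypothetical example of how black wraps a long conditional expression.
record = {"name": "record_param_comms", "pg_name": "0"}


def read_communication_attributes(node: dict) -> dict:
    # Stand-in reader; simply echoes the record it is given.
    return node


comm_attrs = (
    read_communication_attributes(record)
    if record["name"] == "record_param_comms"
    else None
)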
et_replay/tools/comm_replay.py: 6 changes (5 additions, 1 deletion)
@@ -1592,7 +1592,11 @@ def readTrace(self, remotePath: str, rank: int) -> None:
self.comms_trace = commsTraceParser.parseTrace(
self.comms_trace,
self.trace_type,
(self.trace_file if not os.path.isdir(self.trace_file) else f"{self.trace_file}/{rank}.json"),
(
self.trace_file
if not os.path.isdir(self.trace_file)
else f"{self.trace_file}/{rank}.json"
),
rank,
self.backendFuncs.get_world_size(),
)
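Since the entire commit is formatter output, the 69 added and 25 removed lines should be behavior-neutral; black's default safe mode re-parses every file it rewrites and checks that the result is equivalent to the original source. For a stand-alone sanity check, a small hypothetical helper using only the standard library (not part of this commit) can compare syntax trees directly:

import ast


def same_ast(original_src: str, formatted_src: str) -> bool:
    # Comments and whitespace never reach the AST, so equal dumps mean the
    # two sources parse to the same tree.
    return ast.dump(ast.parse(original_src)) == ast.dump(ast.parse(formatted_src))


if __name__ == "__main__":
    before = "x = foo(1,2) # call\n"
    after = "x = foo(1, 2)  # call\n"
    print(same_ast(before, after))  # prints True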
