From f064d8390e8aa77362207f3382430d55ef8cf568 Mon Sep 17 00:00:00 2001
From: Sheng Fu
Date: Mon, 24 Jun 2024 17:59:13 -0700
Subject: [PATCH] Compute node analysis (#127)

Summary:
Pull Request resolved: https://github.com/facebookresearch/param/pull/127

Compute node analysis

Differential Revision: D58610032
---
 et_replay/lib/et_replay_utils.py |   3 +-
 et_replay/tools/et_replay.py     | 271 ++++++++++++++++---------------
 2 files changed, 138 insertions(+), 136 deletions(-)

diff --git a/et_replay/lib/et_replay_utils.py b/et_replay/lib/et_replay_utils.py
index 5159d40b..97e3ec6c 100644
--- a/et_replay/lib/et_replay_utils.py
+++ b/et_replay/lib/et_replay_utils.py
@@ -464,8 +464,7 @@ def build_triton_func(n, resources_dir, async_compile, device):
         func = async_compile.triton(n.name, code, device_str=device)
     except Exception:
         func = async_compile.triton("triton_", code, device_str=device)
-    finally:
-        func = None
+
     return func, 0
 
 
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index 423cd394..8ba83b90 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -116,8 +116,8 @@ def __init__(self):
         # Skip the node if their names contain any of the following strings.
         self.skip_node_names = [
             "DataLoader",
-            "aten::set_",
         ]
+        self.skip_node_ids = []
 
         self.parallel_nodes_parents = []
         # Ids of nodes that need to run in parallel.
@@ -126,7 +126,7 @@ def __init__(self):
         # This is used to pick out a single iteration when trace contains multiple iterations.
         # Basically this label should be captured at the beginning of each iteration so that one iteration
         # is between two consecutive label nodes.
-        self.label = ""
+        self.profile_step_label = "ProfilerStep#"
 
         try:
             from param_bench.et_replay.lib.fb.internals import (
@@ -141,7 +141,9 @@ def __init__(self):
             self.parallel_nodes_parents = add_internal_parallel_nodes_parents(
                 self.parallel_nodes_parents
             )
-            self.label = add_internal_label()
+            self.profile_step_label = add_internal_label()
 
+        self.profile_step_node_ids = []
+
         # Only use for memory profile.
         self.current_allocated_mem = 0
@@ -159,9 +161,10 @@ def __init__(self):
         self.additional_tensors = set()
         self.additional_tensors_size = 0
 
-        # Debug use, record the nodes we skip.
-        self.actual_skip_nodes = []
-        self.actual_skip_nodes_cnt = 0
+        # Debug use: record the nodes we skip,
+        # in the following format:
+        # {"node_name": {node_id: "skip reason", ...}, ...}
+        self.actual_skip_nodes = {}
 
         self.tensor_with_device = True
         # A tensor may appear on multiple devices but here we only store the first device for initialization
@@ -311,6 +314,14 @@ def reset_registry(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def add_skipped_nodes(self, node, reason):
+        # "detach" is only needed for autograd and does not affect execution
+        if node.name == "detach":
+            return
+        if node.name not in self.actual_skip_nodes:
+            self.actual_skip_nodes[node.name] = {}
+        self.actual_skip_nodes[node.name][node.id] = reason
+
     def extract_subgraph(self, root):
         """
         return: all nodes in the subgraph, in the order of node ID (also execution)
@@ -338,40 +349,46 @@ def anlayze_node(node):
             func, output_count = self.build_func(node)
             self.funcs[node.id] = (func, output_count)
 
-        def dfs_traverse(root):
-            for child in root.children:
-                try:
-                    if self.label and self.label in child.name:
-                        self.sorted_nodes.append(child)
+            if func is None:
+                return False, "Failed to build function"
+            else:
+                return True, ""
 
-                    if any(x in child.name for x in self.skip_node_names):
-                        self.actual_skip_nodes.append(child.name)
-                        self.actual_skip_nodes_cnt += 1
-                        continue
+        def dfs_traverse(node):
+            if self.profile_step_label in node.name:
+                self.profile_step_node_ids += [node.id]
 
-                    if is_qualified(child):
-                        self.sorted_nodes.append(child)
-                    else:
-                        if skip_op(child):
-                            self.actual_skip_nodes.append(child.name)
-                            self.actual_skip_nodes_cnt += 1
-                        dfs_traverse(child)
-                except Exception as e:
-                    print(f"Graph parse error: {e}, node id: {child.id}")
-                    exit(1)
+            if node.type == NodeType.OPERATOR:
+                self.sorted_nodes.append(node)
+
+            for child in node.children:
+                dfs_traverse(child)
 
         dfs_traverse(root)
         self.sorted_nodes = sorted(self.sorted_nodes, key=lambda x: x.id)
-        for i in range(len(self.sorted_nodes)):
-            if self.label and self.label in self.sorted_nodes[i].name:
-                self.operators_count.append(i)
-        if len(self.operators_count) > 1:
-            self.sorted_nodes = self.sorted_nodes[
-                self.operators_count[0] + 1 : self.operators_count[1]
+        self.profile_step_node_ids = sorted(self.profile_step_node_ids)
+
+        if len(self.profile_step_node_ids) > 1:
+            # Only execute the ops in the first step
+            start_id = self.profile_step_node_ids[0]
+            end_id = self.profile_step_node_ids[1]
+            self.sorted_nodes = [
+                x for x in self.sorted_nodes if x.id > start_id and x.id < end_id
             ]
 
+        print("#Operators to execute: ", len(self.sorted_nodes))
+        picked_nodes = []
         for node in self.sorted_nodes:
-            anlayze_node(node)
+            if node.name in self.skip_node_names or node.id in self.skip_node_ids:
+                self.add_skipped_nodes(node, "skipped by user")
+                continue
+
+            success, msg = anlayze_node(node)
+            if success:
+                picked_nodes.append(node)
+            else:
+                self.add_skipped_nodes(node, msg)
+        self.sorted_nodes = picked_nodes
 
         # triton kernels are compiled in parallel, need to wait until
         # all kernels are compiled.
@@ -539,28 +556,12 @@ def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1):
                 output_set.add(self.tensors_mapping[(node.id, t_id, False)])
 
     def allocate_tensors(self):
+        integer_data_count = 0
         for node in self.sorted_nodes:
             if node.name == "record_param_comms" and (
                 self.compute_only or self.args.separate
             ):
                 continue
-            if is_fbgemm_forward(node):
-                if self.cpu:
-                    input_args, _ = generate_fbgemm_tensors(
-                        node,
-                        "cpu",
-                        self.args.rows,
-                        self.args.pooling_factor,
-                        self.args.alpha,
-                    )
-                else:
-                    input_args, _ = generate_fbgemm_tensors(
-                        node,
-                        self.cuda,
-                        self.args.rows,
-                        self.args.pooling_factor,
-                        self.args.alpha,
-                    )
             for idx, (data_type, t_id, shape) in enumerate(get_input_tensors(node)):
                 if self.tensor_with_device:
                     t_id = tuple(list(t_id)[:5])
@@ -568,33 +569,30 @@ def allocate_tensors(self):
                 if (
                     t_id in self.dependency_permanent
                     and replay_t_id not in self.tensor_registry_permanent.keys()
-                    and (
-                        node.name == "aten::embedding_bag"
-                        or "fbgemm::split_embedding_codegen_lookup" in node.name
-                        or replay_t_id in self.instantiate
-                    )
+                    and replay_t_id in self.instantiate
                 ):
                     try:
-                        if is_fbgemm_forward(node):
-                            self.tensor_registry_permanent[replay_t_id] = input_args[
-                                idx
-                            ]
-                            if "fbgemm::split_embedding_codegen_lookup" in node.name:
-                                self.unchangeable_intermediate_tensors.add(replay_t_id)
-                        else:
-                            if data_type == "Tensor(signed char)":
-                                dtype, rng = TORCH_DTYPES_RNG["signed char"]
-                            else:
-                                dtype, rng = TORCH_DTYPES_RNG[
-                                    data_type.lstrip("Tensor(").rstrip(")")
-                                ]
-                            self.tensor_registry_permanent[replay_t_id] = rng(shape).to(
-                                dtype
-                            )
-                            if node.name == "aten::embedding_bag":
-                                self.unchangeable_intermediate_tensors.add(replay_t_id)
-                            if node.name == "aten::pin_memory" and idx == 0:
-                                self.cpu_tensor.add(replay_t_id)
+
+                        dtype, rng = TORCH_DTYPES_RNG[
+                            data_type.lstrip("Tensor(").rstrip(")")
+                        ]
+
+                        size = 1
+                        for e in shape:
+                            size *= e
+                        if (
+                            dtype == torch.int64
+                            or dtype == torch.int
+                            or dtype == torch.int8
+                        ):
+                            integer_data_count += size * torch.iinfo(dtype).bits / 8
+                        self.tensor_registry_permanent[replay_t_id] = rng(shape).to(
+                            dtype
+                        )
+                        if node.name == "aten::embedding_bag":
+                            self.unchangeable_intermediate_tensors.add(replay_t_id)
+                        if node.name == "aten::pin_memory" and idx == 0:
+                            self.cpu_tensor.add(replay_t_id)
                     except KeyError:
                         if data_type != "Tensor(nullptr (uninitialized))":
                             print("KeyError: ", node.id, t_id, data_type)
@@ -620,6 +618,8 @@ def allocate_tensors(self):
                 ][i] = (i * nnz)
         ######
 
+        print("Integer data count: ", integer_data_count)
+
     def build_func(self, node):
         if is_fbgemm_forward(node):
             if self.cpu:
@@ -648,9 +648,6 @@ def build_func(self, node):
         else:
             func, output_count = build_torchscript_func(node)
 
-        if not func:
-            self.actual_skip_nodes.append(node.name)
-            self.actual_skip_nodes_cnt += 1
         return func, output_count
 
     def generate_code(self):
@@ -972,7 +969,7 @@ def _generate_run_ops_str(override):
 
         if self.cpu:
             code_str += generate_prefix(
-                self.label,
+                self.profile_step_label,
                 skip_nodes_str,
                 self.trace_file,
                 "cpu",
@@ -982,7 +979,7 @@ def _generate_run_ops_str(override):
         else:
             code_str += generate_prefix(
-                self.label,
+                self.profile_step_label,
                 skip_nodes_str,
                 self.trace_file,
                 self.cuda,
@@ -1089,21 +1086,22 @@ def get_inputs(self, node):
                     inputs.append(self.cuda)
                 else:
                     inputs.append(item)
-            return inputs
+            return inputs, ""
         except Exception as e:
-            print(f"Inputs error: {e} at node: {node.id}")
+            return None, f"Inputs error: {e}"
 
     def run_op(self, node, iter):
         if node.name == "record_param_comms" and not self.compute_only:
-            return
+            return True, ""
 
         if self.debug and iter >= self.numWarmupIters:
             start_ns = time.time_ns()
 
         func, output_count = self.funcs[node.id]
-        if not func:
-            return
-        inputs = self.get_inputs(node)
+
+        inputs, msg = self.get_inputs(node)
+        if not inputs:
+            return False, msg
 
         # Workaround to eliminate the "strides() called on undefined Tensor" error.
         if node.name == "aten::convolution_backward":
@@ -1147,27 +1145,28 @@ def run_op(self, node, iter):
                 # TODO: Simplify this
                 if not tmp:
                     print(f"Not expect that {node.id} has no output.")
-                    return
+                    return False, "No output"
                 for x in tmp:
                     if isinstance(x, list) and isinstance(x[0], torch.Tensor):
                         outputs.extend(x)
                     elif isinstance(x, torch.Tensor):
                         outputs.append(x)
+
+            if node.name == "aten::repeat_interleave":
+                current_len = node.input_shapes[0][0]
+                target_len = node.output_shapes[0][0]
+                if current_len < target_len:
+                    dtype, _ = TORCH_DTYPES_RNG[
+                        node.output_types[0].lstrip("Tensor(").rstrip(")")
+                    ]
+                    tmp = (
+                        torch.zeros(target_len - current_len)
+                        .to(dtype)
+                        .cuda(self.device)
+                    )
+                    outputs[0] = torch.cat((tmp, outputs[0]))
         except Exception as e:
-            print(
-                f"Run op exception Error: {e}, node id: {node.id}, func: {func}, inputs: {inputs}"
-            )
-            exit(1)
-
-        if node.name == "aten::repeat_interleave":
-            current_len = node.input_shapes[0][0]
-            target_len = node.output_shapes[0][0]
-            if current_len < target_len:
-                dtype, _ = TORCH_DTYPES_RNG[
-                    node.output_types[0].lstrip("Tensor(").rstrip(")")
-                ]
-                tmp = torch.zeros(target_len - current_len).to(dtype).cuda(self.device)
-                outputs[0] = torch.cat((tmp, outputs[0]))
+            return False, f"Exception: {e}"
 
         if self.debug and iter >= self.numWarmupIters:
             after_execution = time.time_ns()
@@ -1213,34 +1212,37 @@ def run_op(self, node, iter):
             )
             self.exec_time.append(after_execution - before_execution)
 
+        return True, ""
+
     def init_comms(self):
         pass
 
-    def analyze_ops(self):
-        fused_cnt = 0
-        aten_up_cnt = 0
-        aten_cnt = 0
-        custom_cnt = 0
-        comms_cnt = 0
-        for op in self.actual_skip_nodes:
-            if "fused" in op:
-                fused_cnt += 1
-            elif "aten::record_stream" in op or "aten::set_" in op:
-                aten_up_cnt += 1
-            elif "aten::" in op:
-                aten_cnt += 1
-            elif "fbgemm::" in op:
-                custom_cnt += 1
-            elif "record_param_comms" in op:
-                comms_cnt += 1
-            else:
-                print(op)
-        print("fused cnt: ", fused_cnt)
-        print("aten unsupported cnt: ", aten_up_cnt)
-        print("aten cnt: ", aten_cnt)
-        print("custom cnt: ", custom_cnt)
-        print("comms cnt: ", comms_cnt)
-        print("skipped ops: ", self.actual_skip_nodes)
+    def remove_op_with_runtime_error(self):
+        picked_nodes = []
+        for node in self.sorted_nodes:
+            success, msg = self.run_op(node, 0)  # dry run to catch runtime errors
+            if success:
+                picked_nodes.append(node)
+                continue
+
+            for data_type, t_id, shape in get_output_tensors(node):
+                if self.tensor_with_device:
+                    t_id = tuple(list(t_id)[:5])
+                dtype, rng = TORCH_DTYPES_RNG[data_type.lstrip("Tensor(").rstrip(")")]
+                replay_t_id = self.tensors_mapping[(node.id, t_id, False)]
+                t = rng(shape).to(dtype)
+                if self.tensor_with_device:
+                    if self.tensor_device[replay_t_id] != "cpu" and not self.cpu:
+                        t = t.cuda(self.tensor_device[replay_t_id])
+                else:
+                    if not self.cpu:
+                        t = t.cuda(self.device)
+
+                self.tensor_registry[replay_t_id] = t
+
+            self.add_skipped_nodes(node, msg)
+
+        self.sorted_nodes = picked_nodes
 
     def preprocess_graph(self):
         if not self.compute_only and not self.generator:
@@ -1265,8 +1267,6 @@ def preprocess_graph(self):
 
         self.extract_subgraph(root)
 
-        # self.analyze_ops()
-
         self.analyze_subgraph(root)
         self.analyze_tensors()
 
@@ -1284,6 +1284,10 @@ def preprocess_graph(self):
         else:
             self.allocate_tensors()
             self.reset_registry()
+            self.remove_op_with_runtime_error()
+
+            with open("failed_nodes.json", "w") as outfile:
+                json.dump(self.actual_skip_nodes, outfile, indent=4)
 
     def benchTime(self):
         # A dictionary to save the benchmark result.
@@ -1293,7 +1297,7 @@ def benchTime(self):
         self.preprocess_graph()
         if self.generator:
             return
-        print("Start to execution: ")
+        print("Start to execute: ")
         time.sleep(2)
 
         total_time = 0.0
@@ -1431,13 +1435,12 @@ def benchTime(self):
 
         print("Replay time per iteration: {:.2f} ms".format(total_time / self.numIters))
 
+        n_skipped_nodes = 0
+        for node_name in self.actual_skip_nodes:
+            n_skipped_nodes += len(self.actual_skip_nodes[node_name])
+
         print(
-            "Operator coverage: {} / {} = {}".format(
-                len(self.sorted_nodes),
-                len(self.sorted_nodes) + self.actual_skip_nodes_cnt,
-                len(self.sorted_nodes)
-                / (len(self.sorted_nodes) + self.actual_skip_nodes_cnt),
-            )
+            f"Operator coverage: {len(self.sorted_nodes)} / {len(self.sorted_nodes) + n_skipped_nodes} = {len(self.sorted_nodes) / (len(self.sorted_nodes) + n_skipped_nodes)}"
         )
 
         end_time = datetime.now()
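
Note: the replay now records every node it skips in actual_skip_nodes, keyed as {"node_name": {node_id: "skip reason", ...}}, and preprocess_graph() dumps that mapping to failed_nodes.json. The sketch below is illustrative only and is not part of this patch; it assumes a replay run has already written failed_nodes.json, and the summarize() helper name and default path are placeholders.

# Hypothetical helper (not shipped with this change): summarize the
# failed_nodes.json file written by preprocess_graph() above.
import json
from collections import Counter


def summarize(path="failed_nodes.json"):
    # The file holds {"node_name": {node_id: "skip reason", ...}, ...};
    # node ids become strings after the JSON round trip, which is fine
    # here because we only count entries.
    with open(path) as f:
        skipped = json.load(f)

    per_op = Counter({name: len(ids) for name, ids in skipped.items()})
    per_reason = Counter(
        reason for ids in skipped.values() for reason in ids.values()
    )

    print("Skipped ops by name:")
    for name, cnt in per_op.most_common():
        print(f"  {name}: {cnt}")
    print("Skipped ops by reason:")
    for reason, cnt in per_reason.most_common():
        print(f"  {reason}: {cnt}")


if __name__ == "__main__":
    summarize()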