diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index a10e8c5f..053500ca 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -7,7 +7,6 @@
 import time
 from collections import defaultdict
 from datetime import datetime
-from functools import reduce

 import numpy as np
 import torch
@@ -23,8 +22,6 @@
     generate_suffix,
     get_input_tensors,
     get_output_tensors,
-    has_backward_parent,
-    is_backward_aten,
     is_fbgemm_backward,
     is_fbgemm_forward,
     is_fbgemm_forward_unweighted,
@@ -32,11 +29,10 @@
     is_tensor,
     is_tensor_list,
     skip_op,
-    TORCH_DTYPES_BYTES,
     TORCH_DTYPES_RNG,
     TORCH_DTYPES_RNG_str,
 )
-from et_replay.execution_trace import ExecutionTrace, NodeType
+from et_replay.execution_trace import ExecutionTrace
 from et_replay.utils import trace_handler
 from param_bench.train.compute.python.lib import pytorch as lib_pytorch
 from param_bench.train.compute.python.lib.init_helper import load_modules
@@ -45,7 +41,7 @@
 from torch._inductor.codecache import TritonFuture

 # grid and split_scan_grid are dynamically loaded
-from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
+from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid  # noqa
 from torch.profiler import ExecutionTraceObserver


@@ -147,13 +143,6 @@ def __init__(self):
         # Store the backward fbgemm ops generated in the forward.
         self.fbgemm_backward_ops = []

-        # Dict that stores the input and output tensors of an operator. This is used to detect the
-        # tensors that appear among the child operator and can not be observed at parent-level.
-        self.top_tensors = {}
-        # Additional tensors we allocate since replay at parent-level.
-        self.additional_tensors = set()
-        self.additional_tensors_size = 0
-
         # Debug use, record the nodes we skip.
         self.actual_skip_nodes = []
         self.actual_skip_nodes_cnt = 0
@@ -164,6 +153,8 @@ def __init__(self):
         # the operators.
         self.tensor_device = {}

+        self.replay_id_to_last_node_id_map = {}
+
         # Unrecognized nodes that are neither operators nor predefined label nodes.
         self.exceptional_nodes = set()

@@ -312,16 +303,6 @@ def extract_subgraph(self, root):
         """

         def anlayze_node(node):
-            self.top_tensors[node] = set()
-            for _, t_id, _ in get_input_tensors(node):
-                if self.tensor_with_device:
-                    t_id = tuple(list(t_id)[:5])
-                self.top_tensors[node].add(t_id)
-            for _, t_id, _ in get_output_tensors(node):
-                if self.tensor_with_device:
-                    t_id = tuple(list(t_id)[:5])
-                self.top_tensors[node].add(t_id)
-
             for _, t_id, _ in get_input_tensors(node):
                 if self.tensor_with_device:
                     t_id = tuple(list(t_id)[:5])
@@ -392,95 +373,43 @@ def has_parallel_parent(node):

         assert len(self.parallel_nodes_ids) == len(set(self.parallel_nodes_ids))

-    def analyze_subgraph(self, root):
-        def bfs_traverse(node):
-            for child in node.children:
-                if any(x in child.name for x in self.skip_node_names):
-                    continue
-
-                if is_backward_aten(child) or has_backward_parent(child):
-                    continue
-                else:
-                    if (
-                        child not in self.sorted_nodes
-                        and child.type == NodeType.OPERATOR
-                    ):
-                        node = child.parent
-                        while node and node not in self.sorted_nodes:
-                            node = node.parent
-                        if not node:
-                            self.exceptional_nodes.add(child)
-                            continue
-                        for data_type, t_id, shape in get_output_tensors(child):
-                            if self.tensor_with_device:
-                                t_id = tuple(list(t_id)[:5])
-                            if (
-                                t_id not in self.top_tensors[node]
-                                and t_id in self.dependency_permanent
-                                and t_id not in self.additional_tensors
-                            ):
-                                self.additional_tensors.add(t_id)
-                                if shape:
-                                    self.additional_tensors_size += (
-                                        reduce(lambda x, y: x * y, shape)
-                                        * TORCH_DTYPES_BYTES[
-                                            data_type.lstrip("Tensor(").rstrip(")")
-                                        ]
-                                    )
-                    bfs_traverse(child)
-
-        bfs_traverse(root)
-        print(
-            f"Additional allocated {len(self.additional_tensors)} tensors with total size of {self.additional_tensors_size/1024/1024}MB"
-        )
-
     def analyze_tensors(self):
         def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1):
             # If we did not see this tensor before, add it as a unique tensor.
             if t_id not in self.original_unique_tensors:
                 self.original_unique_tensors.add(t_id)
                 self.replay_unique_tensor_num += 1
-                self.tensors_mapping[(node_id, t_id, input)] = (
-                    self.replay_unique_tensor_num
-                )
-                self.replay_tensors_shapes[
-                    self.tensors_mapping[(node_id, t_id, input)]
-                ] = shape
-                self.tensor_shapes[t_id].add(
-                    (self.tensors_mapping[(node_id, t_id, input)], tuple(shape))
-                )
+                replay_t_id = self.replay_unique_tensor_num
+                self.tensors_mapping[(node_id, t_id, input)] = replay_t_id
+                self.replay_tensors_shapes[replay_t_id] = shape
+                self.tensor_shapes[t_id].add((replay_t_id, tuple(shape)))
+                self.replay_id_to_last_node_id_map[replay_t_id] = node_id
                 if self.tensor_with_device:
-                    self.tensor_device[self.tensors_mapping[(node_id, t_id, input)]] = (
-                        device
-                    )
+                    self.tensor_device[replay_t_id] = device
                 if node_name == "aten::to":
-                    self.special_tensors.add(
-                        self.tensors_mapping[(node_id, t_id, input)]
-                    )
+                    self.special_tensors.add(replay_t_id)
                 return

             # If we saw this tensor before but with a different shape, add it as a unique tensor.
-            for relay_t_id, pre_shape in self.tensor_shapes[t_id]:
+            for replay_t_id, pre_shape in self.tensor_shapes[t_id]:
                 if tuple(shape) == pre_shape:
-                    self.tensors_mapping[(node_id, t_id, input)] = relay_t_id
+                    self.tensors_mapping[(node_id, t_id, input)] = replay_t_id
+                    if self.replay_id_to_last_node_id_map[replay_t_id] < node_id:
+                        self.replay_id_to_last_node_id_map[replay_t_id] = node_id
                     if node_name == "aten::to":
-                        self.special_tensors.add(relay_t_id)
+                        self.special_tensors.add(replay_t_id)
                     return

             self.replay_unique_tensor_num += 1
             self.tensors_mapping[(node_id, t_id, input)] = self.replay_unique_tensor_num
-            self.replay_tensors_shapes[self.tensors_mapping[(node_id, t_id, input)]] = (
-                shape
-            )
-            self.tensor_shapes[t_id].add(
-                (self.tensors_mapping[(node_id, t_id, input)], tuple(shape))
-            )
+            replay_t_id = self.replay_unique_tensor_num
+            self.replay_tensors_shapes[replay_t_id] = shape
+            self.tensor_shapes[t_id].add((replay_t_id, tuple(shape)))
+            self.replay_id_to_last_node_id_map[replay_t_id] = node_id
             if self.tensor_with_device:
-                self.tensor_device[self.tensors_mapping[(node_id, t_id, input)]] = (
-                    device
-                )
+                self.tensor_device[replay_t_id] = device
             if node_name == "aten::to":
-                self.special_tensors.add(self.tensors_mapping[(node_id, t_id, input)])
+                self.special_tensors.add(replay_t_id)

         for node in self.sorted_nodes:
             if node.name == "record_param_comms" and (
@@ -1162,30 +1091,32 @@ def run_op(self, node, iter):
         if self.debug and iter >= self.numWarmupIters:
             after_execution = time.time_ns()

-        for (_, t_id, _), output in zip(get_output_tensors(node), outputs):
+        for _, t_id, _ in get_input_tensors(node):
             if self.tensor_with_device:
                 t_id = tuple(list(t_id)[:5])
-            # if output.isnan().any():
-            #     print(
-            #         node.id,
-            #         t_id,
-            #         output,
-            #         inputs,
-            #     )
+            replay_t_id = self.tensors_mapping[(node.id, t_id, True)]
             if (
-                t_id in self.dependency_permanent
-                and self.tensors_mapping[(node.id, t_id, False)]
-                not in self.unchangeable_intermediate_tensors
+                node.id >= self.replay_id_to_last_node_id_map[replay_t_id]
+                and replay_t_id not in self.instantiate
             ):
+                del self.tensor_registry[replay_t_id]
+
+        for (_, t_id, _), output in zip(get_output_tensors(node), outputs):
+            if self.tensor_with_device:
+                t_id = tuple(list(t_id)[:5])
+
+            if t_id in self.dependency_permanent:
+                replay_t_id = self.tensors_mapping[(node.id, t_id, False)]
                 if (
-                    self.tensors_mapping[(node.id, t_id, False)]
-                    not in self.instantiate
-                    # and self.tensors_mapping[(node.id, t_id, False)]
-                    # not in self.tensor_registry
+                    replay_t_id not in self.unchangeable_intermediate_tensors
+                    and replay_t_id not in self.instantiate
                 ):
-                    self.tensor_registry[
-                        self.tensors_mapping[(node.id, t_id, False)]
-                    ] = output
+                    if node.id < self.replay_id_to_last_node_id_map[replay_t_id]:
+                        self.tensor_registry[replay_t_id] = output
+                    else:
+                        del output
+            else:
+                del output

         if self.profile_memory:
             self.op_allocated_mem[node] = (
@@ -1206,32 +1137,6 @@ def run_op(self, node, iter):
     def init_comms(self):
         pass

-    def analyze_ops(self):
-        fused_cnt = 0
-        aten_up_cnt = 0
-        aten_cnt = 0
-        custom_cnt = 0
-        comms_cnt = 0
-        for op in self.actual_skip_nodes:
-            if "fused" in op:
-                fused_cnt += 1
-            elif "aten::record_stream" in op or "aten::set_" in op:
-                aten_up_cnt += 1
-            elif "aten::" in op:
-                aten_cnt += 1
-            elif "fbgemm::" in op:
-                custom_cnt += 1
-            elif "record_param_comms" in op:
-                comms_cnt += 1
-            else:
-                print(op)
-        print("fused cnt: ", fused_cnt)
-        print("aten unsupported cnt: ", aten_up_cnt)
-        print("aten cnt: ", aten_cnt)
-        print("custom cnt: ", custom_cnt)
-        print("comms cnt: ", comms_cnt)
-        print("skipped ops: ", self.actual_skip_nodes)
-
     def preprocess_graph(self):
         if not self.compute_only and not self.generator:
             self.init_comms()
@@ -1255,10 +1160,6 @@ def preprocess_graph(self):

         self.extract_subgraph(root)

-        # self.analyze_ops()
-
-        # self.analyze_subgraph(root)
-
         self.analyze_tensors()

         tensor_with_multiple_shape_count = 0
@@ -1359,8 +1260,6 @@ def benchTime(self):
                     torch.cuda.synchronize(self.device)
                     if iter >= self.numWarmupIters:
                         total_time += event_1.elapsed_time(event_2)
-                    # Comment out this for now since it will introduce additional cudaMalloc.
-                    # self.reset_registry()
                     prof.step()
             benchmark_result["execution finished"] = True
             print("Execution finished!")
@@ -1397,7 +1296,6 @@ def benchTime(self):
                 torch.cuda.synchronize(self.device)
                 if iter >= self.numWarmupIters:
                     total_time += event_1.elapsed_time(event_2)
-                # self.reset_registry()
             benchmark_result["execution finished"] = True
             print("Execution finished!")