Clean up dead code and recycle tensors when they go out of scope. (#137)
Summary:
Pull Request resolved: #137

Clean up dead code:

    Removed two unused functions: analyze_ops and analyze_subgraph.
    Cleaned up other code related to these two functions.

Recycle a tensor once it is no longer live: record the last op id that refers to each tensor id; when the replay passes that last op id, recycle the tensor.
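A minimal sketch of the last-use recycling idea, assuming a simple node list with explicit input/output tensor ids. The names below (Node, build_last_use, replay) are illustrative only, not the et_replay API; in the actual change the last-use ids live in replay_id_to_last_node_id_map and entries are freed from tensor_registry inside run_op.

from dataclasses import dataclass

@dataclass
class Node:
    id: int
    input_ids: list   # replay tensor ids read by this op
    output_ids: list  # replay tensor ids produced by this op

def build_last_use(nodes):
    # Map each replay tensor id to the id of the last node that reads it.
    last_use = {}
    for node in nodes:
        for t in node.input_ids:
            last_use[t] = max(last_use.get(t, -1), node.id)
    return last_use

def replay(nodes):
    last_use = build_last_use(nodes)
    registry = {}  # replay tensor id -> live tensor (a string stands in here)
    for node in nodes:
        for t in node.output_ids:
            registry[t] = f"output_of_node_{node.id}"
        # Recycle every input whose last reader is this node (or earlier),
        # so its memory can be reclaimed before later ops run.
        for t in node.input_ids:
            if node.id >= last_use[t] and t in registry:
                del registry[t]
    return registry

# Tensor 0 is last read by node 2, so it is freed as soon as node 2 finishes.
nodes = [Node(1, [], [0]), Node(2, [0], [1]), Node(3, [1], [2])]
print(replay(nodes))  # {2: 'output_of_node_3'} -- only the never-read output remains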

With this fix, the CMF V0 model can now be replayed; it previously ran out of memory.

Differential Revision: D59937145
shengfukevin authored and facebook-github-bot committed Jul 23, 2024
1 parent 7b19f58 commit 4aa0326
Showing 1 changed file, et_replay/tools/et_replay.py, with 42 additions and 144 deletions.
@@ -7,7 +7,6 @@
import time
from collections import defaultdict
from datetime import datetime
from functools import reduce

import numpy as np
import torch
@@ -23,20 +22,17 @@
generate_suffix,
get_input_tensors,
get_output_tensors,
has_backward_parent,
is_backward_aten,
is_fbgemm_backward,
is_fbgemm_forward,
is_fbgemm_forward_unweighted,
is_qualified,
is_tensor,
is_tensor_list,
skip_op,
TORCH_DTYPES_BYTES,
TORCH_DTYPES_RNG,
TORCH_DTYPES_RNG_str,
)
from et_replay.execution_trace import ExecutionTrace, NodeType
from et_replay.execution_trace import ExecutionTrace
from et_replay.utils import trace_handler
from param_bench.train.compute.python.lib import pytorch as lib_pytorch
from param_bench.train.compute.python.lib.init_helper import load_modules
@@ -45,7 +41,7 @@
from torch._inductor.codecache import TritonFuture

# grid and split_scan_grid are dynamically loaded
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid # noqa
from torch.profiler import ExecutionTraceObserver


@@ -147,13 +143,6 @@ def __init__(self):
# Store the backward fbgemm ops generated in the forward.
self.fbgemm_backward_ops = []

# Dict that stores the input and output tensors of an operator. This is used to detect the
# tensors that appear among the child operator and can not be observed at parent-level.
self.top_tensors = {}
# Additional tensors we allocate since replay at parent-level.
self.additional_tensors = set()
self.additional_tensors_size = 0

# Debug use, record the nodes we skip.
self.actual_skip_nodes = []
self.actual_skip_nodes_cnt = 0
@@ -164,6 +153,8 @@ def __init__(self):
# the operators.
self.tensor_device = {}

self.replay_id_to_last_node_id_map = {}

# Unrecognized nodes that are neither operators nor predefined label nodes.
self.exceptional_nodes = set()

@@ -312,16 +303,6 @@ def extract_subgraph(self, root):
"""

def anlayze_node(node):
self.top_tensors[node] = set()
for _, t_id, _ in get_input_tensors(node):
if self.tensor_with_device:
t_id = tuple(list(t_id)[:5])
self.top_tensors[node].add(t_id)
for _, t_id, _ in get_output_tensors(node):
if self.tensor_with_device:
t_id = tuple(list(t_id)[:5])
self.top_tensors[node].add(t_id)

for _, t_id, _ in get_input_tensors(node):
if self.tensor_with_device:
t_id = tuple(list(t_id)[:5])
@@ -392,95 +373,43 @@ def has_parallel_parent(node):

assert len(self.parallel_nodes_ids) == len(set(self.parallel_nodes_ids))

def analyze_subgraph(self, root):
def bfs_traverse(node):
for child in node.children:
if any(x in child.name for x in self.skip_node_names):
continue

if is_backward_aten(child) or has_backward_parent(child):
continue
else:
if (
child not in self.sorted_nodes
and child.type == NodeType.OPERATOR
):
node = child.parent
while node and node not in self.sorted_nodes:
node = node.parent
if not node:
self.exceptional_nodes.add(child)
continue
for data_type, t_id, shape in get_output_tensors(child):
if self.tensor_with_device:
t_id = tuple(list(t_id)[:5])
if (
t_id not in self.top_tensors[node]
and t_id in self.dependency_permanent
and t_id not in self.additional_tensors
):
self.additional_tensors.add(t_id)
if shape:
self.additional_tensors_size += (
reduce(lambda x, y: x * y, shape)
* TORCH_DTYPES_BYTES[
data_type.lstrip("Tensor(").rstrip(")")
]
)
bfs_traverse(child)

bfs_traverse(root)
print(
f"Additional allocated {len(self.additional_tensors)} tensors with total size of {self.additional_tensors_size/1024/1024}MB"
)

def analyze_tensors(self):
def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1):
# If we did not see this tensor before, add it as a unique tensor.
if t_id not in self.original_unique_tensors:
self.original_unique_tensors.add(t_id)
self.replay_unique_tensor_num += 1
self.tensors_mapping[(node_id, t_id, input)] = (
self.replay_unique_tensor_num
)
self.replay_tensors_shapes[
self.tensors_mapping[(node_id, t_id, input)]
] = shape
self.tensor_shapes[t_id].add(
(self.tensors_mapping[(node_id, t_id, input)], tuple(shape))
)
replay_t_id = self.replay_unique_tensor_num
self.tensors_mapping[(node_id, t_id, input)] = replay_t_id
self.replay_tensors_shapes[replay_t_id] = shape
self.tensor_shapes[t_id].add((replay_t_id, tuple(shape)))
self.replay_id_to_last_node_id_map[replay_t_id] = node_id
if self.tensor_with_device:
self.tensor_device[self.tensors_mapping[(node_id, t_id, input)]] = (
device
)
self.tensor_device[replay_t_id] = device
if node_name == "aten::to":
self.special_tensors.add(
self.tensors_mapping[(node_id, t_id, input)]
)
self.special_tensors.add(replay_t_id)
return

# If we saw this tensor before but with a different shape, add it as a unique tensor.
for relay_t_id, pre_shape in self.tensor_shapes[t_id]:
for replay_t_id, pre_shape in self.tensor_shapes[t_id]:
if tuple(shape) == pre_shape:
self.tensors_mapping[(node_id, t_id, input)] = relay_t_id
self.tensors_mapping[(node_id, t_id, input)] = replay_t_id
if self.replay_id_to_last_node_id_map[replay_t_id] < node_id:
self.replay_id_to_last_node_id_map[replay_t_id] = node_id
if node_name == "aten::to":
self.special_tensors.add(relay_t_id)
self.special_tensors.add(replay_t_id)
return

self.replay_unique_tensor_num += 1
self.tensors_mapping[(node_id, t_id, input)] = self.replay_unique_tensor_num
self.replay_tensors_shapes[self.tensors_mapping[(node_id, t_id, input)]] = (
shape
)
self.tensor_shapes[t_id].add(
(self.tensors_mapping[(node_id, t_id, input)], tuple(shape))
)
replay_t_id = self.replay_unique_tensor_num
self.replay_tensors_shapes[replay_t_id] = shape
self.tensor_shapes[t_id].add((replay_t_id, tuple(shape)))
self.replay_id_to_last_node_id_map[replay_t_id] = node_id
if self.tensor_with_device:
self.tensor_device[self.tensors_mapping[(node_id, t_id, input)]] = (
device
)
self.tensor_device[replay_t_id] = device
if node_name == "aten::to":
self.special_tensors.add(self.tensors_mapping[(node_id, t_id, input)])
self.special_tensors.add(replay_t_id)

for node in self.sorted_nodes:
if node.name == "record_param_comms" and (
@@ -1162,30 +1091,32 @@ def run_op(self, node, iter):
if self.debug and iter >= self.numWarmupIters:
after_execution = time.time_ns()

for (_, t_id, _), output in zip(get_output_tensors(node), outputs):
for _, t_id, _ in get_input_tensors(node):
if self.tensor_with_device:
t_id = tuple(list(t_id)[:5])
# if output.isnan().any():
# print(
# node.id,
# t_id,
# output,
# inputs,
# )
replay_t_id = self.tensors_mapping[(node.id, t_id, True)]
if (
t_id in self.dependency_permanent
and self.tensors_mapping[(node.id, t_id, False)]
not in self.unchangeable_intermediate_tensors
node.id >= self.replay_id_to_last_node_id_map[replay_t_id]
and replay_t_id not in self.instantiate
):
del self.tensor_registry[replay_t_id]

for (_, t_id, _), output in zip(get_output_tensors(node), outputs):
if self.tensor_with_device:
t_id = tuple(list(t_id)[:5])

if t_id in self.dependency_permanent:
replay_t_id = self.tensors_mapping[(node.id, t_id, False)]
if (
self.tensors_mapping[(node.id, t_id, False)]
not in self.instantiate
# and self.tensors_mapping[(node.id, t_id, False)]
# not in self.tensor_registry
replay_t_id not in self.unchangeable_intermediate_tensors
and replay_t_id not in self.instantiate
):
self.tensor_registry[
self.tensors_mapping[(node.id, t_id, False)]
] = output
if node.id < self.replay_id_to_last_node_id_map[replay_t_id]:
self.tensor_registry[replay_t_id] = output
else:
del output
else:
del output

if self.profile_memory:
self.op_allocated_mem[node] = (
@@ -1206,32 +1137,6 @@ def init_comms(self):
def init_comms(self):
pass

def analyze_ops(self):
fused_cnt = 0
aten_up_cnt = 0
aten_cnt = 0
custom_cnt = 0
comms_cnt = 0
for op in self.actual_skip_nodes:
if "fused" in op:
fused_cnt += 1
elif "aten::record_stream" in op or "aten::set_" in op:
aten_up_cnt += 1
elif "aten::" in op:
aten_cnt += 1
elif "fbgemm::" in op:
custom_cnt += 1
elif "record_param_comms" in op:
comms_cnt += 1
else:
print(op)
print("fused cnt: ", fused_cnt)
print("aten unsupported cnt: ", aten_up_cnt)
print("aten cnt: ", aten_cnt)
print("custom cnt: ", custom_cnt)
print("comms cnt: ", comms_cnt)
print("skipped ops: ", self.actual_skip_nodes)

def preprocess_graph(self):
if not self.compute_only and not self.generator:
self.init_comms()
@@ -1255,10 +1160,6 @@ def preprocess_graph(self):

self.extract_subgraph(root)

# self.analyze_ops()

# self.analyze_subgraph(root)

self.analyze_tensors()

tensor_with_multiple_shape_count = 0
@@ -1359,8 +1260,6 @@ def benchTime(self):
torch.cuda.synchronize(self.device)
if iter >= self.numWarmupIters:
total_time += event_1.elapsed_time(event_2)
# Comment out this for now since it will introduce additional cudaMalloc.
# self.reset_registry()
prof.step()
benchmark_result["execution finished"] = True
print("Execution finished!")
@@ -1397,7 +1296,6 @@ def benchTime(self):
torch.cuda.synchronize(self.device)
if iter >= self.numWarmupIters:
total_time += event_1.elapsed_time(event_2)
# self.reset_registry()
benchmark_result["execution finished"] = True
print("Execution finished!")
