Skip to content

Commit

Permalink
revert changes to patching.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Jan 29, 2025
1 parent 11c7fce commit 9e479b0
Showing 1 changed file with 1 addition and 27 deletions.
28 changes: 1 addition & 27 deletions sharktank/sharktank/utils/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def _patch(name: str, m: torch.nn.Module):
orig_forward = m.forward

def wrapper(*args, **kwargs):
self.before_forward(name, m, *args, **kwargs)
results = orig_forward(*args, **kwargs)
self.after_forward(name, m, results)
return results
Expand All @@ -35,12 +34,6 @@ def wrapper(*args, **kwargs):
for name, m in module.named_modules():
_patch(name, m)

def before_forward(
self, module_name: str, module: torch.nn.Module, *args, **kwargs
):
"""Called before every patched forward() function with results."""
...

def after_forward(self, module_name: str, module: torch.nn.Module, results):
"""Called after every patched forward() function with results."""
...
Expand All @@ -62,36 +55,17 @@ def __init__(self):
# Map of module_name to last used index for duplicated tensors.
self.duplicate_tensors = {}

def before_forward(self, module_name, module, *args, **kwargs):
for idx, arg in enumerate(args):
if not isinstance(arg, torch.Tensor):
continue
result_tensor = torch.detach(arg).contiguous().to(device="cpu").clone()
name_base = f"{module_name}_input_{idx}"
if name_base in self.tensors:
orig_dup = self.tensors[name_base]
del self.tensors[name_base]
self.duplicate_tensors[name_base] = 0
self.tensors[f"{name_base}#0"] = orig_dup
elif name_base in self.duplicate_tensors:
index = self.duplicate_tensors[name_base] + 1
self.duplicate_tensors[name_base] = index
self.tensors[f"{name_base}#{index}"] = result_tensor
else:
self.tensors[name_base] = result_tensor

def after_forward(self, module_name: str, module: torch.nn.Module, results):
if not isinstance(results, torch.Tensor):
return

result_tensor = torch.detach(results).contiguous().to(device="cpu").clone()

if module_name in self.tensors:
orig_dup = self.tensors[module_name]
del self.tensors[module_name]
self.duplicate_tensors[module_name] = 0
self.tensors[f"{module_name}#0"] = orig_dup
elif module_name in self.duplicate_tensors:
if module_name in self.duplicate_tensors:
index = self.duplicate_tensors[module_name] + 1
self.duplicate_tensors[module_name] = index
self.tensors[f"{module_name}#{index}"] = result_tensor
Expand Down

0 comments on commit 9e479b0

Please sign in to comment.