[torch][ao] Add customizable loss function to NodeAccuracySummary (pytorch#136282)

Summary:
Add a loss function callback to NodeAccuracySummary so that users can
supply their own loss function when comparing results.

Also, fix some type errors and propagate clearer exception messages when
unexpected tensor comparisons occur. Finally, make
`generate_numeric_debug_handle` robust to being called multiple times on
the same model by avoiding reuse of existing IDs.
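
As a quick orientation, here is a minimal sketch of how the new callback is
consumed; `comparison_results` is assumed to come from `compare_results(...)`
(see the diff below), and the reporting loop itself is illustrative, not part
of this change.

```python
# Minimal sketch, not part of this diff: consuming the new loss() callback.
# Assumes `comparison_results` came from compare_results(ref, actual), which
# maps debug_handle_id -> NodeAccuracySummary.
import torch.nn.functional as F

for handle, summary in comparison_results.items():
    for result in summary.results:
        # sqnr is the pre-existing built-in metric; loss() is new and accepts
        # any Callable[[Tensor, Tensor], Tensor], e.g. mean squared error.
        print(f"handle={handle}: sqnr={result.sqnr}, mse={result.loss(F.mse_loss)}")
```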

Test Plan: Added a test for this case in `test_numeric_debugger`.

Reviewed By: jerryzh168

Differential Revision: D62898297

Pull Request resolved: pytorch#136282
Approved by: https://github.com/jerryzh168
dulinriley authored and pytorchmergebot committed Sep 20, 2024
1 parent 687e5cf commit f3c54cc
Showing 2 changed files with 95 additions and 12 deletions.
test/quantization/pt2e/test_numeric_debugger.py (52 additions, 2 deletions)
@@ -25,8 +25,8 @@
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase


-def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
-    debug_handle_map: Dict[torch.fx.Node, int] = {}
+def _extract_debug_handles(model) -> Dict[str, int]:
+    debug_handle_map: Dict[str, int] = {}

     for node in model.graph.nodes:
         if (
@@ -187,3 +187,53 @@ def test_extract_results_from_loggers(self):
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
+
+    def test_added_node_gets_unique_id(self) -> None:
+        m = TestHelperModules.Conv2dThenConv1d()
+        example_inputs = m.example_inputs()
+        m = capture_pre_autograd_graph(m, example_inputs)
+        assert isinstance(m, torch.fx.GraphModule)
+        generate_numeric_debug_handle(m)
+        ref_handles = _extract_debug_handles(m)
+        ref_counter = Counter(ref_handles.values())
+        for k, v in ref_counter.items():
+            self.assertEqual(
+                v,
+                1,
+                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
+            )
+
+        # Now that we have unique ids, add a new node into the graph and re-generate
+        # to make sure that the new node gets a unique id.
+        last_node = next(iter(reversed(m.graph.nodes)))
+        with m.graph.inserting_before(last_node):
+            arg = last_node.args[0]
+            self.assertIsInstance(arg, tuple)
+            arg = arg[0]
+            # Add a function that only requires a single tensor input.
+            n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
+            arg.replace_all_uses_with(n, lambda x: x != n)
+        m.recompile()
+
+        # Regenerate handles; make sure only the new relu node gets a new id
+        # and that it doesn't clash with any of the existing ids.
+        generate_numeric_debug_handle(m)
+        handles_after_modification = _extract_debug_handles(m)
+        handles_counter = Counter(handles_after_modification.values())
+        for name, handle in ref_handles.items():
+            self.assertIn(name, handles_after_modification)
+            # Check that the handle was unchanged.
+            self.assertEqual(handles_after_modification[name], handle)
+            # Check that the total count was unchanged.
+            ref_count = ref_counter[handle]
+            after_count = handles_counter[handle]
+            self.assertEqual(
+                after_count,
+                ref_count,
+                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected {ref_count}",
+            )
+
+        # Check for relu specifically. Avoid hardcoding the handle id since it
+        # may change with future node ordering changes.
+        self.assertNotEqual(handles_after_modification["relu_default"], 0)
+        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
torch/ao/quantization/pt2e/_numeric_debugger.py (43 additions, 10 deletions)
@@ -1,7 +1,7 @@
 import copy
 import logging
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple

 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
@@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
     """Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
     The graph nodes of input model is modified inplace.
     """
-    unique_id = 0
+    unique_id = -1
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    for node in graph_module.graph.nodes:
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
+        )
+    unique_id += 1
+
     for node in graph_module.graph.nodes:
         if node.op in ["output", "placeholder"]:
             continue
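
To make the effect of this guard concrete, a hedged sketch of re-annotating an
edited graph; `gm` stands in for any `torch.fx.GraphModule`, and the key
constants are the module-level ones defined in this file.

```python
# Hedged sketch: re-running annotation after a graph edit. `gm` is any
# torch.fx.GraphModule; CUSTOM_KEY and NUMERIC_DEBUG_HANDLE_KEY are the
# constants defined in torch/ao/quantization/pt2e/_numeric_debugger.py.
from torch.ao.quantization.pt2e._numeric_debugger import (
    CUSTOM_KEY,
    NUMERIC_DEBUG_HANDLE_KEY,
    generate_numeric_debug_handle,
)

def debug_handles(gm):
    # Map node name -> assigned debug handle id for annotated nodes.
    return {
        n.name: n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        for n in gm.graph.nodes
        if NUMERIC_DEBUG_HANDLE_KEY in n.meta.get(CUSTOM_KEY, {})
    }

generate_numeric_debug_handle(gm)
before = debug_handles(gm)
# ... insert a new node into gm.graph here ...
generate_numeric_debug_handle(gm)
after = debug_handles(gm)
# With the max-ID scan above, existing nodes keep their ids and only the
# new node receives a fresh, non-clashing one.
assert all(after[name] == hid for name, hid in before.items())
```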
@@ -134,6 +143,17 @@ def sqnr(self) -> torch.Tensor:
             self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
         )

+    def loss(
+        self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    ) -> torch.Tensor:
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
+            )
+        return loss_function(
+            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
+        )
+
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
         # out.
@@ -149,6 +169,10 @@ def __post_init__(self) -> None:

         if not isinstance(self.ref, torch.Tensor):
             raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+            )


 @dataclass(frozen=True)
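
Taken together with the new `loss` method above, a hedged illustration with
made-up tensors; the import path is the private module this diff touches.

```python
# Hedged sketch: the shape guard plus the new loss() callback, using
# made-up tensors. QuantizationComparisonResult lives in the private module
# torch/ao/quantization/pt2e/_numeric_debugger.py touched by this diff.
import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e._numeric_debugger import (
    QuantizationComparisonResult,
)

ok = QuantizationComparisonResult(actual=torch.randn(2, 3), ref=torch.randn(2, 3))
print(ok.loss(F.l1_loss))  # any Callable[[Tensor, Tensor], Tensor] works

try:
    QuantizationComparisonResult(actual=torch.randn(2, 3), ref=torch.randn(3, 2))
except ValueError as e:
    print(e)  # "Cannot compare tensors with different shapes: ..."
```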
@@ -197,8 +221,8 @@ def extract_results_from_loggers(


 def compare_results(
-    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
-    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
+    ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
+    actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
 ) -> Dict[int, NodeAccuracySummary]:
     """Given two dict mapping from `debug_handle_id` (int) to list of tensors
     return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@@ -220,16 +244,25 @@
             )
             continue
         actual_name, actual_stack, actual_stats = actual_results[debug_handle]
+        try:
+            results = [
+                QuantizationComparisonResult(actual=a, ref=b)
+                for a, b in zip(actual_stats, ref_stats)
+            ]
+        except Exception as e:
+            # Add extra information for an exception from QuantizationComparisonResult
+            # if the shapes didn't match, to include the handle and the node names.
+            raise ValueError(
+                f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
+            ) from e
+
         comparisons[debug_handle] = NodeAccuracySummary(
             handle=debug_handle,
-            actual_node_name=actual_name,
+            actual_node_name=actual_name or "",
             actual_module_stack=_module_stack_to_str(actual_stack),
-            ref_node_name=ref_name,
+            ref_node_name=ref_name or "",
             ref_module_stack=_module_stack_to_str(ref_stack),
-            results=[
-                QuantizationComparisonResult(actual=a, ref=b)
-                for a, b in zip(actual_stats, ref_stats)
-            ],
+            results=results,
         )

     return comparisons
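
A hedged sketch of what the improved error path looks like from the caller's
side; `ref_results` and `actual_results` are assumed to come from
`extract_results_from_loggers(...)`, and the message text is illustrative.

```python
# Hedged sketch: on a shape mismatch, the ValueError raised above carries the
# debug handle and node names, chained onto the original error via `from e`.
try:
    comparisons = compare_results(ref_results, actual_results)
except ValueError as e:
    # e.g. "For numeric_debug_handle=3 from ref node conv1 and actual node conv1"
    print(e)
    # The underlying "Cannot compare tensors with different shapes: ..." error:
    print(e.__cause__)
```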
