dbr quant: enable reference module support for torch.qint32 (pytorch#73493)

Summary:
Pull Request resolved: pytorch#73493

This PR enables basic support for reference modules in DBR quant.
For now, the support is limited to:
1. modules that have reference versions defined (functions are not supported)
2. the torch.qint32 dtype only

Currently, the reference module logic is enabled whenever the dtype is
torch.qint32, because that is the dtype needed earliest for the first
use case. A future PR will support more dtypes and also add the
`is_reference` flag to the API.
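
For context, a minimal end-to-end sketch of the intended usage, mirroring the
new `test_conv_int32_reference_model` test added below (the exact import path
of the prototype `_quantize_dbr` entry points is an assumption here):

```python
import torch
import torch.nn as nn

# assumption: the prototype DBR quant entry points live in this module,
# matching how the new test below calls `_quantize_dbr.prepare/convert`
import torch.ao.quantization._quantize_dbr as _quantize_dbr
from torch.quantization import QConfig, MinMaxObserver

# any module with a reference implementation works; the test uses Conv2d
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()

# a qconfig whose observers target torch.qint32 now triggers the
# reference module path in DBR quant
int32_obs_ctr = MinMaxObserver.with_args(dtype=torch.qint32)
int32_qconfig = QConfig(weight=int32_obs_ctr, activation=int32_obs_ctr)
qconfig_dict = {'': int32_qconfig}

example_input = (torch.randn(1, 1, 1, 1),)
mp = _quantize_dbr.prepare(m, qconfig_dict, example_input)
mp(*example_input)              # calibrate
mq = _quantize_dbr.convert(mp)  # Conv2d is swapped for its reference version
out = mq(*example_input)
```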

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR.test_conv_int32_reference_model
```

Reviewed By: jerryzh168

Differential Revision: D34520759

Pulled By: vkuzo

fbshipit-source-id: 363db715315c5c7c20962a1818330ce288948778
(cherry picked from commit 6ccdfe2)
vkuzo authored and pytorchmergebot committed Mar 4, 2022
1 parent 5787a36 commit 727debb
Showing 9 changed files with 108 additions and 25 deletions.
20 changes: 20 additions & 0 deletions test/quantization/dbr/test_quantize_dbr.py
@@ -22,6 +22,8 @@
from torch.quantization import (
ObserverBase,
FakeQuantizeBase,
QConfig,
MinMaxObserver,
)
from torch.quantization.quantize_fx import (
prepare_fx,
@@ -1341,6 +1343,24 @@ def test_jit_tracing_removes_aliases(self):
):
FileCheck().check_count("aten::alias", 0, exactly=True).run(graph)

def test_conv_int32_reference_model(self):
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
int32_obs_ctr = MinMaxObserver.with_args(dtype=torch.qint32)
int32_qconfig = QConfig(weight=int32_obs_ctr, activation=int32_obs_ctr)
qconfig_dict = {'': int32_qconfig}
mp = _quantize_dbr.prepare(m, qconfig_dict, (torch.randn(1, 1, 1, 1),))
mp(torch.randn(1, 1, 1, 1))
mq = _quantize_dbr.convert(mp)
res = mq(torch.randn(1, 1, 1, 1))
mqt = torch.jit.trace(mq, (torch.randn(1, 1, 1, 1),))
# verify the right ops are present:
# x0 -> quant -> (dequant -> conv_ref -> quant) -> dequant -> x1
FileCheck()\
.check_count("aten::quantize_per_tensor", 2, exactly=True)\
.run(mqt.graph)
FileCheck()\
.check_count("aten::dequantize", 2, exactly=True)\
.run(mqt.graph)

@skipIfNoFBGEMM
class TestQuantizeDBRMultipleOps(QuantizeDBRTestCase):
8 changes: 4 additions & 4 deletions torch/ao/quantization/_dbr/auto_trace_rewriter.py
@@ -43,10 +43,10 @@ def _maybe_update_args_with_quants(self, args, arg_quant_infos, target):
new_first_arg.append(args[0][idx])
else:
# create a quant node
scale, zp = input_arg_quant_info
scale, zp, dtype = input_arg_quant_info
quant = super().create_node(
'call_function', torch.quantize_per_tensor,
(args[0][idx], scale.item(), zp.item(), torch.quint8), {}, None, None)
(args[0][idx], scale.item(), zp.item(), dtype), {}, None, None)
new_first_arg.append(quant)
new_args = [new_first_arg, *args[1:]]
elif target == torch.cat:
@@ -61,10 +61,10 @@ def _maybe_update_args_with_quants(self, args, arg_quant_infos, target):
new_args.append(args[idx])
else:
# create a quant node
scale, zp = input_arg_quant_info
scale, zp, dtype = input_arg_quant_info
quant = super().create_node(
'call_function', torch.quantize_per_tensor,
(args[idx], scale.item(), zp.item(), torch.quint8), {}, None, None)
(args[idx], scale.item(), zp.item(), dtype), {}, None, None)
new_args.append(quant)
args = tuple(new_args)
return args
8 changes: 8 additions & 0 deletions torch/ao/quantization/_dbr/mappings.py
@@ -6,6 +6,7 @@
from torch.ao.quantization.quantization_mappings import (
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS,
)

import operator
@@ -67,6 +68,10 @@
set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
module_types_supported_by_quantization |= \
set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
module_types_supported_by_quantization |= \
set(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS.keys())
module_types_supported_by_quantization |= \
set(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS.values())
module_types_supported_by_quantization |= set([
# these are quantizeable modules which do not need swaps
nn.ReLU,
@@ -144,6 +149,9 @@
for a, b in DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.items():
a_related_to_b.add((a, b))
a_related_to_b.add((b, a))
for a, b in DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS.items():
a_related_to_b.add((a, b))
a_related_to_b.add((b, a))
for a, b in fp32_to_int8_fun_mapping.items():
a_related_to_b.add((a, b))
a_related_to_b.add((b, a))
6 changes: 3 additions & 3 deletions torch/ao/quantization/_dbr/model_utils.py
@@ -118,9 +118,9 @@ def attach_scale_zp_values_to_model(
if hasattr(module, '_auto_quant_state'):
qstate: AutoQuantizationState = module._auto_quant_state # type: ignore[assignment]
for tensor_id, observer in qstate.tensor_id_to_observer.items():
activation_int8_quantized = \
observer.dtype in [torch.quint8, torch.qint8]
if activation_int8_quantized:
activation_int8_or_int32_quantized = \
observer.dtype in [torch.quint8, torch.qint8, torch.qint32]
if activation_int8_or_int32_quantized:
scale, zp = observer.calculate_qparams()
# tensor_id_to_observer is a ModuleDict which has to have string keys
# tensor_id_to_scale_zp is a normal dict which can have int keys
15 changes: 15 additions & 0 deletions torch/ao/quantization/_dbr/module_swap_utils.py
@@ -5,9 +5,13 @@
from torch.nn.intrinsic import _FusedModule
from ..utils import (
activation_is_int8_quantized,
activation_is_int32_quantized,
op_is_int8_dynamically_quantized,
)
from torch.ao.quantization import swap_module
from torch.ao.quantization.quantization_mappings import (
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS,
)

def _swap_child_modules(
module: torch.nn.Module,
@@ -42,6 +46,7 @@ def _swap_child_modules(
continue
activation_int8_quantized = activation_is_int8_quantized(qconfig)
op_int8_dynamically_quantized = op_is_int8_dynamically_quantized(qconfig)
activation_int32_quantized = activation_is_int32_quantized(qconfig)

# Get the output observer from qstate and attach it to the module,
# to match the API for Eager mode module swaps
@@ -58,6 +63,16 @@ def _swap_child_modules(
if not type(mod) in dynamic_mappings:
continue
reassign[local_fqn] = swap_module(mod, dynamic_mappings, {})
elif activation_int32_quantized:
# For now, only apply reference logic to modules quantized to
# int32. Do it automatically.
# TODO(future PR): extend this logic to more dtypes, and add
# the is_reference API flag instead of doing this automatically.
# Note: swap modules only does the swap if the mapping for this
# module exists.
reassign[local_fqn] = swap_module(
mod, DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS, {})

# TODO(future PR): add support for other dtypes

for key, value in reassign.items():
48 changes: 36 additions & 12 deletions torch/ao/quantization/_dbr/quantization_state.py
@@ -42,11 +42,15 @@
get_seen_q_op_info_of_end_of_fusion,
)

from torch.ao.quantization.utils import (
activation_is_int32_quantized,
)

OpConvertInfo = Tuple[
# quantized equivalent of original op (None means keep original)
Optional[Callable],
# arg_quant_infos, each element is (scale, zp) for quantized and None otherwise
List[Optional[Tuple[float, int]]],
# arg_quant_infos, each element is (scale, zp, dtype) for quantized and None otherwise
List[Optional[Tuple[float, int, torch.dtype]]],
# arg_dequant_infos, each element is True if this arg needs a dequant
List[bool],
# packed param name, if the op has a packed param
@@ -454,9 +458,11 @@ def op_convert_before_hook(
quant_info = arg_quant_infos[tensor_arg_idx]
dequant_info = arg_dequant_infos[tensor_arg_idx]
if quant_info is not None:
scale, zp = quant_info
arg = torch.quantize_per_tensor(arg, scale, zp, torch.quint8)
elif dequant_info is True:
scale, zp, dtype = quant_info
arg = torch.quantize_per_tensor(arg, scale, zp, dtype)
if dequant_info is True:
# Note: both quant and dequant paths are taken for
# reference ops.
arg = arg.dequantize()
new_first_arg.append(arg)
tensor_arg_idx += 1
@@ -470,9 +476,11 @@ def op_convert_before_hook(
quant_info = arg_quant_infos[tensor_arg_idx]
dequant_info = arg_dequant_infos[tensor_arg_idx]
if quant_info is not None:
scale, zp = quant_info
arg = torch.quantize_per_tensor(arg, scale, zp, torch.quint8)
elif dequant_info is True:
scale, zp, dtype = quant_info
arg = torch.quantize_per_tensor(arg, scale, zp, dtype)
if dequant_info is True:
# Note: both quant and dequant paths are taken for
# reference ops.
arg = arg.dequantize()
new_args.append(arg)
tensor_arg_idx += 1
@@ -518,10 +526,22 @@ def op_convert_after_hook(
global_op_idx: List[int],
) -> Any:
"""
This function is called aftern an op call in a converted model.
TODO: add dequant, if needed
This function is called after an op call in a converted model.
"""
# TODO(future PR): improve performance by moving this out of the
# path of non-reference ops
seen_q_op_info = self._get_cur_seen_q_op_info()

if seen_q_op_info.is_reference_op_at_inference:
# given the current reference module design,
# we need to quantize to the target dtype
output_tensor_info = seen_q_op_info.output_tensor_infos[0]
tensor_id, inf_dtype = \
output_tensor_info.id, output_tensor_info.inf_dtype
scale, zp = self.tensor_id_to_scale_zp[tensor_id]
output = torch.quantize_per_tensor(
output, scale, zp, inf_dtype)

if self.log_op_outputs:
output_clone = clone_detach_tensor_without_dispatch(output)
seen_q_op_info = self._get_cur_seen_q_op_info()
@@ -795,11 +815,15 @@ def _first_call_op_prepare_before_hook_create_subgraphs(
op_type_is_module = isinstance(op, torch.nn.Module)
op_type = type(op) if op_type_is_module else op # type: ignore[assignment]
qconfig = get_cur_qconfig(self.qconfig_dict, fqn, op_type)
# TODO(future PR): use API flag instead of qconfig for is_reference
is_reference_op_at_inference = \
qconfig is not None and activation_is_int32_quantized(qconfig)
self.idx_to_seen_q_op_infos[self.idx] = SeenQOpInfo(
self.idx, op_type, op_type_is_module, fqn, arg_tensor_infos, [],
packable_tensor_idx_to_name, packable_nontensor_idx_to_arg,
packable_tensor_kwarg_name_to_name,
op_packing_only_uses_module_attributes, qconfig, None)
op_packing_only_uses_module_attributes, qconfig, None,
is_reference_op_at_inference)

return args, kwargs

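The `op_convert_before_hook` / `op_convert_after_hook` changes above produce the
quant -> dequant -> fp32 op -> quant -> dequant pattern around reference ops that
the new test checks for. A rough sketch of that numerics pattern, with placeholder
scale/zero_point values standing in for the observed qparams:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 1, 1)
w = torch.randn(1, 1, 1, 1)
scale, zp = 0.1, 0  # placeholder qparams; real values come from the observers

xq = torch.quantize_per_tensor(x, scale, zp, torch.qint32)       # input quant
x_fp32 = xq.dequantize()                                         # dequant before the reference op
y_fp32 = F.conv2d(x_fp32, w)                                     # reference op runs in fp32
yq = torch.quantize_per_tensor(y_fp32, scale, zp, torch.qint32)  # output re-quantized by the after hook
y = yq.dequantize()                                              # dequant for the next fp32 consumer
```
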
20 changes: 15 additions & 5 deletions torch/ao/quantization/_dbr/utils.py
@@ -102,6 +102,8 @@ class SeenQOpInfo:
qconfig: QConfigAny
# fusion_info for the op, is None if no fusion is found
fusion_info: Optional[FusionInfo]
# True if this op is a reference op during inference
is_reference_op_at_inference: bool

def __repr__(self) -> str:
s = f"(type): {self.type}\n"
@@ -648,7 +650,7 @@ def clone_detach_tensor_without_dispatch(x: torch.Tensor) -> torch.Tensor:
def get_input_args_quant_dequant_info(
seen_q_op_info: SeenQOpInfo,
tensor_id_to_scale_zp: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[List[Optional[Tuple[float, int]]], List[bool], bool]:
) -> Tuple[List[Optional[Tuple[float, int, torch.dtype]]], List[bool], bool]:
"""
Returns a list of information about the tensor inputs to the current op.
@@ -674,7 +676,7 @@ def get_input_args_quant_dequant_info(
# dequants
[False, False]
"""
quant_infos: List[Optional[Tuple[float, int]]] = []
quant_infos: List[Optional[Tuple[float, int, torch.dtype]]] = []
dequant_infos: List[bool] = []

# determine the expected output dtype
Expand All @@ -690,12 +692,20 @@ def get_input_args_quant_dequant_info(
tensor_id = input_arg.id
if input_arg.inf_dtype != output_dtype:
any_arg_quant_or_dequant_needed = True
if output_dtype == torch.quint8:
if output_dtype in (torch.quint8, torch.qint32):
assert tensor_id in tensor_id_to_scale_zp
scale, zp = tensor_id_to_scale_zp[tensor_id]
# TODO: return this to the caller
quant_infos.append((scale, zp,)) # type: ignore[arg-type]
dequant_infos.append(False)
quant_infos.append((scale, zp, output_dtype)) # type: ignore[arg-type]
if output_dtype == torch.qint32:
# For now, we treat all qint32 ops as reference, so
# we add a dequant before the op.
# TODO(future PR): extend this to more dtypes
# TODO(future PR): use is_reference flag instead of
# assuming
dequant_infos.append(True)
else:
dequant_infos.append(False)
else:
quant_infos.append(None)
dequant_infos.append(True)
6 changes: 6 additions & 0 deletions torch/ao/quantization/utils.py
@@ -190,6 +190,12 @@ def activation_is_int8_quantized(qconfig):
"""
return activation_dtype(qconfig) in [torch.quint8, torch.qint8]

def activation_is_int32_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
quantized to int32 or not
"""
return activation_dtype(qconfig) == torch.qint32

def weight_is_quantized(qconfig):
""" Given a qconfig, decide if the weight needs to be
quantized or not
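A small usage sketch of the new helper next to the existing int8 check (this
assumes `MinMaxObserver` accepts `dtype=torch.qint32`, which the new DBR test
relies on):

```python
import torch
from torch.quantization import QConfig, MinMaxObserver
from torch.ao.quantization.utils import (
    activation_is_int8_quantized,
    activation_is_int32_quantized,
)

int32_obs_ctr = MinMaxObserver.with_args(dtype=torch.qint32)
int32_qconfig = QConfig(activation=int32_obs_ctr, weight=int32_obs_ctr)

print(activation_is_int8_quantized(int32_qconfig))   # False
print(activation_is_int32_quantized(int32_qconfig))  # True
```
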
2 changes: 1 addition & 1 deletion torch/nn/quantized/_reference/modules/utils.py
@@ -84,7 +84,7 @@ def _quantize_weight(
weight_zero_point: torch.Tensor,
weight_axis: torch.Tensor):
if weight_qscheme == torch.per_tensor_affine:
if weight_dtype in [torch.quint8, torch.qint8]:
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
elif weight_dtype == torch.float16:
weight = weight.to(weight_dtype)