Add int4 weight-only embedding QAT (#947)
Based on changes in #886 by @TiRune
andrewor14 authored Oct 1, 2024
1 parent 71be315 commit 5a4857e
Showing 5 changed files with 307 additions and 23 deletions.
23 changes: 23 additions & 0 deletions test/quantization/test_qat.py
@@ -81,6 +81,17 @@ def forward(self, x):
        x = self.linear2(x)
        return x

class M2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 512)

    def example_inputs(self):
        return (torch.randint(1, 10, (1, 512)),)

    def forward(self, x):
        return self.embedding(x)


class TestQAT(unittest.TestCase):
    SEED = 123
@@ -669,5 +680,17 @@ def test_composable_qat_quantizer(self):
        values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
        self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"])

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
    def test_qat_4w_embedding(self):
        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
        model = M2()
        x = model.example_inputs()
        out = model(*x)
        quantizer = Int4WeightOnlyEmbeddingQATQuantizer()
        prepared = quantizer.prepare(model)
        prepared_out = prepared(*x)
        converted = quantizer.convert(model)
        converted_out = converted(*x)

if __name__ == "__main__":
    unittest.main()
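The new test above is a smoke test of the prepare/convert flow. For context, here is a minimal sketch of how the new quantizer would typically be used for QAT fine-tuning, assuming only the prepare/convert API exercised in the test; the training loop, loss, and optimizer are illustrative placeholders, not part of this commit:

import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer

model = M2()  # toy embedding model from the test above
quantizer = Int4WeightOnlyEmbeddingQATQuantizer()

# Swap nn.Embedding modules for fake-quantized equivalents, then fine-tune.
prepared = quantizer.prepare(model)
optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
for _ in range(10):
    out = prepared(*model.example_inputs())
    loss = out.sum()  # placeholder loss for illustration only
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Swap the fake-quantized embeddings for int4 weight-only quantized ones.
# Mirroring the test above, convert is called on the same (already prepared) model.
converted = quantizer.convert(model)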
4 changes: 4 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
@@ -13,6 +13,9 @@
from ._module_swap_api import (
    Int8DynActInt4WeightQATLinear,
)
from .embedding import (
    Int4WeightOnlyEmbeddingQATQuantizer,
)

__all__ = [
    "disable_4w_fake_quant",
@@ -23,6 +26,7 @@
    "int8_dynamic_activation_int4_weight_fake_quantize",
    "ComposableQATQuantizer",
    "Int4WeightOnlyQATQuantizer",
    "Int4WeightOnlyEmbeddingQATQuantizer",
    "Int8DynActInt4WeightQATQuantizer",
    "Int8DynActInt4WeightQATLinear",
]
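With this export, the new quantizer is importable directly from the prototype QAT package, which is the import path the test above uses:

from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer

quantizer = Int4WeightOnlyEmbeddingQATQuantizer()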
36 changes: 13 additions & 23 deletions torchao/quantization/prototype/qat/_module_swap_api.py
@@ -28,23 +28,23 @@
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_get_qmin_qmax,
)


# TODO: deprecate this flow in favor of the tensor subclass flow under qat/api.py
# This is currently needed for DDP and FSDP1, which are not compatible with the
# subclass flow.
# TODO: make module swap the main flow again, and remove the quantize_ flow
# TODO: rename this file to linear.py

# =========================================================
# | Linear int8 dynamic activations + int4 weight QAT |
# =========================================================


class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
"""
Quantizer for performing QAT on a model, where linear layers have int8
dynamic per token fake quantized activations and int4 fake quantized
grouped per channel weights.
Note: This quantizer is implemented using module swaps and may be
deprecated in the future. Please use `Int8DynActInt4WeightQATQuantizer`
instead if possible.
"""

def prepare(
@@ -92,7 +92,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):

# Load weights and qparams into quantized linear
n_bit = 4
(qmin, qmax) = child._get_qmin_qmax(n_bit)
(qmin, qmax) = _get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
(act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
x, self.scales_precision, self.zero_points_precision,
)
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
(act_qmin, act_qmax) = _get_qmin_qmax(8)
x_fq = _fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)
@@ -170,7 +170,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
(weight_qmin, weight_qmax) = _get_qmin_qmax(4)
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
@@ -183,12 +183,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
w_fq = self.weight
return F.linear(x_fq, w_fq)

# TODO: move this to common util
def _get_qmin_qmax(self, n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)


def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
"""
@@ -206,19 +200,15 @@ def disable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
mod.disable_fake_quant()


# ==================
# | int4wo QAT |
# ==================
# ===================================
# | Linear int4 weight-only QAT |
# ===================================


class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
"""
Quantizer for performing QAT on a model, where linear layers have
int4 fake quantized grouped per channel weights.
Note: This quantizer is implemented using module swaps and may be
deprecated in the future. Please use `Int4WeightOnlyQATQuantizer`
instead if possible.
"""

def prepare(
(the remainder of this file's diff and the diffs for the other changed files are not shown)

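The _get_qmin_qmax method removed from _module_swap_api.py above is replaced by a shared helper pulled in through the new import at the top of that file; its new location is not part of the diffs shown here. As a reference, a minimal sketch of what that helper computes, reconstructed from the removed method body:

def _get_qmin_qmax(n_bit: int):
    # Signed integer range of an n-bit value,
    # e.g. n_bit=4 gives (-8, 7) and n_bit=8 gives (-128, 127).
    qmin = -(2 ** (n_bit - 1))
    qmax = 2 ** (n_bit - 1) - 1
    return (qmin, qmax)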