Add generic fake quantized linear for QAT #1020

Open
wants to merge 7 commits into gh/andrewor14/4/base
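The tests in this diff exercise the new `FakeQuantizeConfig` and `FakeQuantizedLinear` APIs. Below is a minimal usage sketch assembled from the test code in this diff; the constructor signatures are inferred from the tests, not from final documentation:

```python
import torch
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 per-token asymmetric activations + int4 per-group symmetric weights,
# mirroring the 8da4w configuration used in test_fake_quantized_linear_8da4w
activation_config = FakeQuantizeConfig(8, "per_token", symmetric=False)
weight_config = FakeQuantizeConfig(4, group_size=128)

fq_linear = FakeQuantizedLinear(
    256, 688, bias=False,
    activation_config=activation_config,
    weight_config=weight_config,
)

x = torch.randn(100, 256)
y = fq_linear(x)  # forward pass with fake quantization applied to activations and weight
```

Weight-only schemes can be expressed by passing `activation_config=None`, as in `test_fake_quantized_linear_4w`.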
229 changes: 186 additions & 43 deletions test/quantization/test_qat.py
@@ -11,17 +11,27 @@
import unittest

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.dtypes import (
TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
QuantizationGranularity,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.prototype.qat.linear import (
FakeQuantizedLinear,
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_get_qmin_qmax,
_GenericFakeQuantize,
)
from torchao.quantization.quant_api import (
@@ -92,15 +102,10 @@ def forward(self, x):
class TestQAT(unittest.TestCase):
SEED = 123

def _get_qmin_qmax(self, n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_fake_quantize_per_channel_group(self):
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(qmin, qmax) = _get_qmin_qmax(n_bit)
group_size = 128

torch.manual_seed(self.SEED)
@@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_fake_quantize_per_token(self):
(qmin, qmax) = self._get_qmin_qmax(8)
(qmin, qmax) = _get_qmin_qmax(8)

torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
@@ -165,11 +170,11 @@ def _set_ptq_weight(
Int4WeightOnlyQATLinear,
)
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(qmin, qmax) = _get_qmin_qmax(n_bit)
group_size = qat_linear.weight_fake_quantizer.config.group_size
if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
fp32_weight = qat_linear.weight
group_size = qat_linear.groupsize
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
@@ -180,7 +185,7 @@
elif isinstance(ptq_linear, WeightOnlyInt4Linear):
assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
(q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
qat_linear.weight, n_bit, qat_linear.groupsize,
qat_linear.weight, n_bit, group_size,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"), qat_linear.inner_k_tiles,
@@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
torch.manual_seed(self.SEED)
m = M()
m2 = copy.deepcopy(m)
subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# Convert QAT model and compare model values
subclass_model = subclass_quantizer.convert(subclass_model)
module_swap_model = module_swap_quantizer.convert(module_swap_model)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_qat_8da4w_quantizer_meta_weights(self):
@@ -275,9 +285,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled)
self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled)
self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled)
self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled)
self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled)
self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled)

# Disabled fake quant is just a normal linear
m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
@@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

# Re-enable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled)
self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled)
self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled)
self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled)
self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled)
self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
@@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self):
the numerics of existing fake quantize ops in PyTorch in both
the forward and the backward passes.
"""
(qmin, qmax) = self._get_qmin_qmax(4)
(qmin, qmax) = _get_qmin_qmax(4)
py_input = torch.randn(16, 64).float().requires_grad_()
py_s = torch.randn(16).float()
py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
@@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self):
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
inner_k_tiles = 8
@@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self):
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)
subclass_quantizer = Int4WeightOnlyQATQuantizer(
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
module_swap_quantizer = Int4WeightOnlyQATQuantizer(
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = [i.to(device).to(dtype) for i in m.example_inputs()]
x2 = copy.deepcopy(x)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
self._assert_close_4w(qat_out, ptq_out)

# Convert QAT model and compare model values
subclass_model = subclass_quantizer.convert(subclass_model)
module_swap_model = module_swap_quantizer.convert(module_swap_model)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

class _MyQATQuantizer(TwoStepQuantizer):
"""
@@ -603,5 +624,127 @@ def test_qat_4w_embedding(self):
converted = quantizer.convert(model)
converted_out = converted(*x)

def test_fake_quantize_config(self):
"""
Test initialization and property setting of `FakeQuantizeConfig`.
"""
# basic configs
per_token_config = FakeQuantizeConfig(8, "per_token")
self.assertEqual(per_token_config.bit_width, 8)
self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)
self.assertIsNone(per_token_config.group_size)
per_channel_config = FakeQuantizeConfig(4, "per_channel")
self.assertEqual(per_channel_config.bit_width, 4)
self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL)
self.assertIsNone(per_channel_config.group_size)

# initialize per_group config using only group size
per_group_config = FakeQuantizeConfig(4, group_size=32)
self.assertEqual(per_group_config.bit_width, 4)
self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
self.assertEqual(per_group_config.group_size, 32)

# set granularity after initialization, should accept str as before
per_group_config.granularity = "per_token"
self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_TOKEN)

# set group_size after initialization, should also update granularity
per_group_config.group_size = 16
self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
self.assertEqual(per_group_config.group_size, 16)

# bad config1: no granularity or group size provided
with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"):
FakeQuantizeConfig(8)

# bad config2: 'per_group' but no group size
with self.assertRaisesRegex(ValueError, "no group_size was set"):
FakeQuantizeConfig(8, "per_group")

# bad config3: group size was set but granularity was not 'per_group'
with self.assertRaisesRegex(ValueError, "group_size was set"):
FakeQuantizeConfig(8, "per_token", group_size=16)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_fake_quantized_linear_8da4w(self):
"""
Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`.
"""
group_size = 128
torch.manual_seed(self.SEED)
fq_linear = FakeQuantizedLinear(
256,
688,
bias=False,
activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False),
weight_config=FakeQuantizeConfig(4, group_size=group_size),
)

def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
"""
# activations
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
(qmin, qmax) = _get_qmin_qmax(8)
x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)

# weights
(s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
zp = zp.to(torch.int32)
(qmin, qmax) = _get_qmin_qmax(4)
w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
return F.linear(x_fq, w_fq)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256)
x2 = copy.deepcopy(x)
fq_out = fq_linear(x)
baseline_out = linear_forward_8da4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_fake_quantized_linear_4w(self):
"""
Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`.
"""
group_size = 128
weight_config = FakeQuantizeConfig(
bit_width=4,
group_size=group_size,
symmetric=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)
torch.manual_seed(self.SEED)
fq_linear = FakeQuantizedLinear(
256,
688,
bias=False,
activation_config=None,
weight_config=weight_config,
)

def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Baseline for int4 weight only fake quantization that simulates the tinygemm kernel.
"""
(qmin, qmax) = _get_qmin_qmax(4, symmetric=False)
(s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32)
zp = zp.to(torch.int32)
w_fq = _fake_quantize_per_channel_group(
weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT,
)
return F.linear(x, w_fq)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256)
x2 = copy.deepcopy(x)
fq_out = fq_linear(x)
baseline_out = linear_forward_4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()