diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index cea659e61..fcab07c91 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -395,7 +395,10 @@ def test_eval_wrapper(self):
 
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import (
+            AffineQuantizedTensor,
+            LinearActQuantizedTensor,
+        )
         from torchao.quantization.quant_primitives import MappingType
         import copy
@@ -409,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         quant_max = 7
 
         # TODO: make a general helper function?
+        # input settings
         def get_per_token_block_size(x):
             block_size = []
             for i in range(len(x.shape)-1):
@@ -421,13 +425,18 @@ def get_per_token_block_size(x):
         input_target_dtype = torch.int8
         input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
+        def dynamic_quant(linear):
+            # note: order is important
+            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
-        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+        dynamic_quant(m.linear1)
+        dynamic_quant(m.linear2)
+        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
 
         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -461,9 +470,6 @@ def test_quantized_tensor_subclass_int4(self):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
 
-        # weight only quantization
-        input_quant_func = None
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
@@ -475,7 +481,6 @@ def to_quantized(weight):
                 zero_point_dtype=zero_point_dtype,
                 preserve_zero=preserve_zero,
                 zero_point_domain=ZeroPointDomain.FLOAT,
-                input_quant_func=input_quant_func,
             )
 
         m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
@@ -506,16 +511,13 @@ def test_quantized_tensor_subclass_int8(self):
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
 
-        # weight only quantization
-        input_quant_func = None
-
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
         def to_quantized(weight):
             block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
         m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
         m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
@@ -532,5 +534,63 @@ def to_quantized(weight):
         torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int8_dyn_quant(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import LinearActQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def dynamic_quant(linear):
+            # note: order is important
+            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+
+        dynamic_quant(m.linear1)
+        dynamic_quant(m.linear2)
+        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+        change_linear_weights_to_int8_dqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 607cb7776..bc40ffeaf 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -21,7 +21,9 @@
     quantize_affine,
     dequantize_affine,
     ZeroPointDomain,
+    MappingType,
 )
+from torchao.kernel.intmm import int_scaled_matmul
 from .utils import find_multiple
 from typing import Tuple, Optional, Callable
 
@@ -36,6 +38,30 @@
 
 aten = torch.ops.aten
 
+
+def _aqt_is_int8(aqt):
+    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
+    return (
+        aqt.int_data.dtype == torch.int8 and
+        (aqt.quant_min is None or aqt.quant_min == -128) and
+        (aqt.quant_max is None or aqt.quant_max == 127)
+    )
+
+def _aqt_is_int8_reduced_range(aqt):
+    return (
+        aqt.int_data.dtype == torch.int8 and
+        aqt.quant_min == -127 and
+        (aqt.quant_max is None or aqt.quant_max == 127)
+    )
+
+def _aqt_is_uint4(aqt):
+    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
+    # TODO: use torch.uint4
+    return (
+        aqt.int_data.dtype == torch.int32 and
+        (aqt.quant_min is None or aqt.quant_min == 0) and
+        (aqt.quant_max is None or aqt.quant_max == 15)
+    )
+
 
 class QuantizedLinearWeightBase(torch.Tensor):
     """
@@ -643,7 +669,6 @@ def __new__(
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        input_quant_func: Optional[Callable] = None,
         dtype=None,
         # TODO: remove args and kwargs
         *args,
@@ -670,7 +695,6 @@ def __init__(
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        input_quant_func: Optional[Callable] = None,
         dtype=None,
         *args,
         **kwargs
@@ -682,12 +706,11 @@ def __init__(
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.zero_point_domain = zero_point_domain
-        self.input_quant_func = input_quant_func
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
-            f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
+            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def dequantize(self, output_dtype=None):
@@ -696,14 +719,14 @@ def dequantize(self, output_dtype=None):
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.input_quant_func, self.dtype]
+        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        block_size, shape, quant_min, quant_max, zero_point_domain, input_quant_func, dtype = tensor_attributes
+        block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
         return cls(
             int_data,
             scale,
@@ -713,7 +736,6 @@ def __tensor_unflatten__(
             quant_min,
             quant_max,
             zero_point_domain,
-            input_quant_func=input_quant_func,
             dtype=dtype,
             strides=outer_stride,
         )
@@ -730,7 +752,6 @@ def from_float(
         eps = None,
         scale_dtype = None,
         zero_point_dtype = None,
-        input_quant_func = None,
         preserve_zero = True,
         zero_point_domain = ZeroPointDomain.INT,
     ):
@@ -745,7 +766,6 @@ def from_float(
             quant_min,
             quant_max,
             zero_point_domain,
-            input_quant_func=input_quant_func,
             dtype=input_float.dtype
         )
 
@@ -759,27 +779,63 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if weight_qtensor.input_quant_func is None:
-                is_cuda = args[0].is_cuda
-                is_cpu = args[0].device == torch.device("cpu")
-                # weight only quantization
-                is_int8 = (
-                    weight_qtensor.int_data.dtype == torch.int8 and
-                    weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and
-                    weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127
-                )
-                is_uint4 = (
-                    weight_qtensor.int_data.dtype == torch.int32 and
-                    weight_qtensor.quant_min == 0 and
-                    weight_qtensor.quant_max == 15
-                )
+            is_cuda = weight_qtensor.is_cuda
+            is_cpu = weight_qtensor.device == torch.device("cpu")
+            if isinstance(weight_qtensor, AffineQuantizedTensor):
+                weight_is_int8 = _aqt_is_int8(weight_qtensor)
+                weight_is_uint4 = _aqt_is_uint4(weight_qtensor)
+
+                if isinstance(input_tensor, AffineQuantizedTensor):
+                    # if input tensor is quantized, either dispatch to the int8 mm kernel
+                    # or just dequantize the input tensor
+                    input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
+                    input_tensor_dtype_is_expected = input_tensor.dtype in [
+                        torch.float,
+                        torch.bfloat16
+                    ]
+                    if (
+                        is_cuda and
+                        input_is_int8 and
+                        input_tensor_dtype_is_expected
+                    ):
+                        #
+                        # 1. do the matrix form of dot(X_i, W_j)
+                        #
+                        #
+                        # 2. rescale the output
+                        #
+                        # in cases with large matrices, y_dot_int32 can grow sufficiently
+                        # large that y_dot_int32 * a float16 scale is greater than the maximum
+                        # value of a float16 (which results in a value of inf even if multiplying
+                        # by the other scale would bring it within the expected range)
+
+                        x_vals_int8 = input_tensor.int_data
+                        x_scales = input_tensor.scale
+                        w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
+                        w_scales = weight_qtensor.scale
+                        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+                        y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
+
+                        y = (y_dot_scaled * w_scales).reshape(
+                            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
+                        )
+
+                        # can downcast only at the very end
+                        output_dtype = input_tensor.dtype
+                        y = y.to(output_dtype)
+                        if bias is not None:
+                            y += bias
+                        return y
+                    else:
+                        input_tensor = input_tensor.dequantize()
+
+                # weight only quantization
                 # TODO: enable cpu and mps path as well
                 # TODO: make sure weight dimension matches the expectation of the int4mm kernel
                 # TODO: move this to TinygemmAffineQuantizedTensor
                 if (
                     is_cuda and
-                    is_uint4 and
+                    weight_is_uint4 and
                     weight_qtensor.dtype == torch.bfloat16 and
                     len(weight_qtensor.shape) == 2 and
                     weight_qtensor.block_size[0] == 1 and
@@ -796,7 +852,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                     return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
                 elif (
                     is_cpu and
-                    is_int8 and
+                    weight_is_int8 and
                     len(weight_qtensor.shape) == 2 and
                     len(weight_qtensor.block_size) == 2 and
                     weight_qtensor.block_size[0] == 1 and
@@ -805,18 +861,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                     # TODO: enable mps path as well
                    # per channel int8 weight only quantizated mm
                     return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+                else:
+                    weight_tensor = weight_qtensor.dequantize()
+                    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
             else:
-                # dynamic quantization
-                input_tensor = weight_qtensor.input_quant_func(input_tensor)
-                input_tensor = input_tensor.dequantize()
-                weight_tensor = weight_qtensor.dequantize()
-                return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+                if isinstance(input_tensor, AffineQuantizedTensor):
+                    input_tensor = input_tensor.dequantize()
+                return torch.nn.functional.linear(input_tensor, weight_qtensor, bias)
 
-        try:
-            with torch._C.DisableTorchFunctionSubclass():
-                return func(*args, **kwargs)
-        except:
-            print(f"ERR: subclass doesn't implement {func}")
implement {func}") + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) def _get_to_kwargs(self, *args, **kwargs): @@ -844,7 +898,6 @@ def to(self, *args, **kwargs): self.quant_min, self.quant_max, self.zero_point_domain, - self.input_quant_func, **kwargs, ) @@ -858,7 +911,6 @@ def _apply_fn_to_data(self, fn): self.quant_min, self.quant_max, self.zero_point_domain, - self.input_quant_func, dtype=self.dtype, ) @@ -900,16 +952,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): args[1], None if len(args) == 2 else args[2], ) - if weight_qtensor.input_quant_func is not None: - # dynamic quantization - input_tensor = weight_qtensor.input_quant_func(input_tensor) - input_tensor = input_tensor.dequantize() weight_tensor = weight_qtensor.dequantize() return func(input_tensor, weight_tensor, bias) - if (func is aten.detach.default or - func is aten.clone.default or - func is aten._to_copy.default): + if func is aten.detach.default: return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) @@ -933,3 +979,126 @@ def __torch_dispatch__(cls, func, types, args, kwargs): kwargs, args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) + + raise NotImplementedError( + f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" + ) + + +class LinearActQuantizedTensor(torch.Tensor): + """ + Applies activation quantization for linear operator + """ + def __new__( + cls, + original_weight_tensor: torch.Tensor, + input_quant_func: Callable, + ): + kwargs = {} + dtype = original_weight_tensor.dtype + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + shape = original_weight_tensor.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + original_weight_tensor: torch.Tensor, + input_quant_func: Callable, + ): + self.original_weight_tensor = original_weight_tensor + self.input_quant_func = input_quant_func + + def __tensor_flatten__(self): + return ["original_weight_tensor"], [self.input_quant_func] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + original_weight_tensor = tensor_data_dict["original_weight_tensor"] + input_quant_func = tensor_attributes + return cls( + original_weight_tensor, + input_quant_func, + ) + + @classmethod + def from_float( + cls, + input_float, + input_quant_func, + ): + return cls( + input_float, + input_quant_func, + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(weight_tensor, LinearActQuantizedTensor): + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return torch.nn.functional.linear(aqt, original_weight_tensor, bias) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.original_weight_tensor), + self.input_quant_func, + ) + + def __torch_dispatch__(cls, func, types, args, kwargs): + if ( + func in [aten.mm.default, aten.addmm.default] + and args[0].is_floating_point() + ): + if func == aten.addmm.default: + assert args[1].shape[-1] == args[2].shape[0], ( + f"need mat1 shape: 
{args[1].shape} final" + f"dim to match mat2 shape: {args[2].shape} first dim " + ) + input_tensor, weight_qtensor, bias = ( + args[1], + args[2], + args[0], + ) + aqt = self.input_quant_func(input_tensor) + return func(bias, aqt, weight_tensor) + else: + assert args[0].shape[-1] == args[1].shape[0], ( + f"need mat1 shape: {args[0].shape} final dim" + f"to match mat2 shape: {args[1].shape} first dim" + ) + input_tensor, weight_qtensor, bias = ( + args[0], + args[1], + None if len(args) == 2 else args[2], + ) + aqt = self.input_quant_func(input_tensor) + return func(aqt, weight_tensor, bias) + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + raise NotImplementedError( + f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" + )