diff --git a/sparsebit/quantization/common.py b/sparsebit/quantization/common.py
index 03b1eb0..d8efe0f 100644
--- a/sparsebit/quantization/common.py
+++ b/sparsebit/quantization/common.py
@@ -5,6 +5,7 @@
 class Granularity(Enum):
     LAYERWISE = 0
     CHANNELWISE = 1
+    GROUPWISE = 2
 
 
 class QuantTarget(Enum):
@@ -44,6 +45,10 @@ def get_qscheme(qscheme):
         return torch.per_channel_symmetric
     if qscheme == "per-channel-affine":
         return torch.per_channel_affine
+    if qscheme == "per-group-symmetric":
+        return "per-group-symmetric"
+    if qscheme == "per-group-affine":
+        return "per-group-affine"
     raise TypeError(
         "only support a qscheme equals to per-[tensor/channel]-[affine/symmetric] , not {}".format(
             qscheme
diff --git a/sparsebit/quantization/observers/aciq.py b/sparsebit/quantization/observers/aciq.py
index af46f73..929416c 100644
--- a/sparsebit/quantization/observers/aciq.py
+++ b/sparsebit/quantization/observers/aciq.py
@@ -63,16 +63,18 @@ def __init__(self, config, qdesc):
         self.gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5)
 
     def calc_laplace_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             b = torch.mean(torch.abs(data - data.mean(1).unsqueeze(1)), dim=1)
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        elif self.granularity == Granularity.LAYERWISE:
             b = torch.mean(torch.abs(data - data.mean()))
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         is_half_range = data.min() >= 0
         if (
-            self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
+            self.qdesc.scheme
+            in [torch.per_channel_affine, torch.per_tensor_affine, "per-group-affine"]
             and is_half_range
         ):
             max_val = self.alpha_laplace_positive[self.qdesc.bit] * b
@@ -85,25 +87,26 @@
     def calc_gaus_minmax(self):
         if self.qdesc.target == QuantTarget.FEATURE:
             batch_size = self.data_cache.get_batch_size()
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data.max(axis=1).values
             min_val = data.min(axis=1).values
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        elif self.granularity == Granularity.LAYERWISE:
             max_val = data.max()
             min_val = data.min()
-            self.data_cache.get_batch_size
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         is_half_range = data.min() >= 0
-        num_elements = data.numel()
+        num_elements = data[0].numel()
         if self.qdesc.target == QuantTarget.FEATURE:
             num_elements /= batch_size
         std = ((max_val - min_val) * self.gaus_const) / (
             (2 * math.log(num_elements)) ** 0.5
         )
         if (
-            self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
+            self.qdesc.scheme
+            in [torch.per_channel_affine, torch.per_tensor_affine, "per-group-affine"]
             and is_half_range
         ):
             max_val = self.alpha_gaus_positive[self.qdesc.bit] * std
diff --git a/sparsebit/quantization/observers/base.py b/sparsebit/quantization/observers/base.py
index e10d555..184054e 100644
--- a/sparsebit/quantization/observers/base.py
+++ b/sparsebit/quantization/observers/base.py
@@ -10,7 +10,10 @@ def __init__(self, qdesc):
         self._data_cache = []
 
     def update(self, data):
-        self._data_cache.append(data)
+        if self.ch_axis != 0:
+            self._data_cache.append(data.transpose(self.ch_axis, 0))
+        else:
+            self._data_cache.append(data)
 
     def reset(self):
         self._data_cache = []
@@ -23,16 +26,42 @@ def get_data_for_calibration(self, granularity: Granularity):
         assert granularity in [
             Granularity.LAYERWISE,
             Granularity.CHANNELWISE,
-        ], "only layerwise or channelwise quantization are supported now!"
-        if granularity == Granularity.CHANNELWISE:
-            data = torch.cat(self._data_cache, dim=self.qdesc.ch_axis)
-            if self.qdesc.ch_axis != 0:
-                data = data.transpose(0, self.qdesc.ch_axis)
-            data = data.flatten(1)
-        elif granularity == Granularity.LAYERWISE:
-            data = torch.cat([d.reshape(-1) for d in self._data_cache], axis=0)
-        else:
-            raise NotImplementedError
+            Granularity.GROUPWISE,
+        ], "only layerwise, channelwise and groupwise quantization are supported now!"
+        if granularity == Granularity.LAYERWISE:
+            data = torch.cat([d.reshape(1, -1) for d in self._data_cache], axis=1)
+        elif granularity == Granularity.CHANNELWISE:
+            data = torch.cat(
+                [d.reshape(d.shape[0], -1) for d in self._data_cache], axis=1
+            )
+        elif granularity == Granularity.GROUPWISE:
+            if self.target == QuantTarget.FEATURE:  # feature group on channel dim
+                assert (
+                    self._data_cache[0].shape[0] <= self.group_size
+                    or self._data_cache[0].shape[0] % self.group_size == 0
+                ), "channel num must be divisible by group size! got group size {} and channel num {}".format(
+                    self.group_size, self._data_cache[0].shape[0]
+                )
+                group_num = max(self._data_cache[0].shape[0] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size(self._data_cache[0].shape[0])
+                data = torch.cat(
+                    [d.reshape(group_num, -1) for d in self._data_cache], axis=1
+                )
+            else:  # weight group on ic dim
+                assert (
+                    self._data_cache[0].shape[1] <= self.group_size
+                    or self._data_cache[0].shape[1] % self.group_size == 0
+                ), "ic num must be divisible by group size! got group size {} and ic num {}".format(
+                    self.group_size, self._data_cache[0].shape[1]
+                )
+                group_num = max(self._data_cache[0].shape[1] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size(self._data_cache[0].shape[1])
+                data = torch.cat(
+                    [d.reshape(d.shape[0] * group_num, -1) for d in self._data_cache],
+                    axis=1,
+                )
         return data
 
     def get_batch_size(self):
@@ -44,6 +73,18 @@ def get_data_cache(self):
         assert len(self._data_cache), "No data cached!"
         return self._data_cache
 
+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
+
+    @property
+    def ch_axis(self):
+        return self.qdesc.ch_axis
+
 
 class Observer(nn.Module):
     def __init__(self, config, qdesc):
@@ -79,9 +120,17 @@ def calc_qparams_with_minmax(self, min_val, max_val):
         return scale, zero_point
 
     @property
-    def is_perchannel(self):
-        return self.qdesc.is_perchannel
+    def granularity(self):
+        return self.qdesc.granularity
 
     @property
     def is_symmetric(self):
         return self.qdesc.is_symmetric
+
+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
diff --git a/sparsebit/quantization/observers/kl_histogram.py b/sparsebit/quantization/observers/kl_histogram.py
index e7c15f2..247ada1 100644
--- a/sparsebit/quantization/observers/kl_histogram.py
+++ b/sparsebit/quantization/observers/kl_histogram.py
@@ -102,10 +102,8 @@ def __init__(self, config, qdesc):
         self.bins = 2048
 
     def calc_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(
-                Granularity.CHANNELWISE
-            ).cpu()
+        data = self.data_cache.get_data_for_calibration(self.granularity).cpu()
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             channel = data.shape[0]
             abs_max = data.abs().max(axis=1).values
             _min = torch.empty(channel)
@@ -131,8 +129,7 @@ def calc_minmax(self):
                    _max[c] = th[c]
             self.max_val = _max.to(self.device)
             self.min_val = _min.to(self.device)
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE).cpu()
+        elif self.granularity == Granularity.LAYERWISE:
             abs_max = data.abs().max()
             th = get_best_threshold(
                 data=data,
@@ -147,5 +144,7 @@
                 if data.min() < 0
                 else torch.zeros(1).to(self.device)
             )
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         return self.min_val, self.max_val
diff --git a/sparsebit/quantization/observers/minmax.py b/sparsebit/quantization/observers/minmax.py
index 0c470c7..a0d1289 100644
--- a/sparsebit/quantization/observers/minmax.py
+++ b/sparsebit/quantization/observers/minmax.py
@@ -1,7 +1,7 @@
 import torch
 from sparsebit.quantization.observers import Observer as BaseObserver
 from sparsebit.quantization.observers import register_observer
-from sparsebit.quantization.common import Granularity
+from sparsebit.quantization.common import Granularity, QuantTarget
 
 
 @register_observer
@@ -12,13 +12,13 @@ def __init__(self, config, qdesc):
         super(Observer, self).__init__(config, qdesc)
 
     def calc_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data.max(axis=1).values
             min_val = data.min(axis=1).values
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        else:  # Granularity.LAYERWISE
             min_val, max_val = data.min(), data.max()
+
         self.data_cache.reset()
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
diff --git a/sparsebit/quantization/observers/moving_average.py b/sparsebit/quantization/observers/moving_average.py
index c72a745..0f9bb84 100644
--- a/sparsebit/quantization/observers/moving_average.py
+++ b/sparsebit/quantization/observers/moving_average.py
@@ -1,7 +1,7 @@
 import torch
 from sparsebit.quantization.observers import Observer as BaseObserver
 from sparsebit.quantization.observers import register_observer
-from sparsebit.quantization.common import QuantTarget
+from sparsebit.quantization.common import Granularity, QuantTarget
 
 
 @register_observer
@@ -14,6 +14,9 @@ def __init__(self, config, qdesc):
             hasattr(config.OBSERVER, "MOVING_AVERAGE")
             and self.qdesc.target == QuantTarget.FEATURE
         ), "Moving_average observer only support feature observing!"
+        assert (
+            self.granularity == Granularity.LAYERWISE
+        ), "Moving_average observer only support layerwise quantization!"
         self.ema_ratio = config.OBSERVER.MOVING_AVERAGE.EMA_RATIO
 
     def calc_minmax(self):
diff --git a/sparsebit/quantization/observers/mse.py b/sparsebit/quantization/observers/mse.py
index a97b6d1..f5e758c 100644
--- a/sparsebit/quantization/observers/mse.py
+++ b/sparsebit/quantization/observers/mse.py
@@ -16,21 +16,23 @@ def __init__(self, config, qdesc):
         self.alpha = config.OBSERVER.PERCENTILE.ALPHA
 
     def calc_minmax(self, data_c_first):
-        if self.is_perchannel:
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data_c_first.max(axis=1).values
             min_val = data_c_first.min(axis=1).values
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             min_val, max_val = data_c_first.min(), data_c_first.max()
+        else:
+            raise NotImplementedError
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
         return self.min_val, self.max_val
 
     def calc_qparams(self):
-        data_c_first = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data_c_first = self.data_cache.get_data_for_calibration(self.granularity)
         self.data_cache.reset()
         min_val, max_val = self.calc_minmax(data_c_first)
         x_f = data_c_first.to(self.device)
-        if self.is_perchannel:
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             best_scale = torch.tensor(
                 [1.0 for _ in range(data_c_first.shape[0])], device=self.device
             )
@@ -40,24 +42,27 @@ def calc_qparams(self):
             loss_min = torch.tensor(
                 [1e10 for _ in range(data_c_first.shape[0])], device=self.device
             )
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             best_scale, best_zero_point = None, None
             loss_min = 1e10
+        else:
+            raise NotImplementedError
         for i in range(80):
             cur_min_val = min_val * (1.0 - (i * 0.01))
             cur_max_val = max_val * (1.0 - (i * 0.01))
             scale, zero_point = self.calc_qparams_with_minmax(cur_min_val, cur_max_val)
             x_dq = STE.apply(x_f, scale, zero_point, self.qdesc, Backend.VIRTUAL)
-            if self.is_perchannel:
-                loss = mse_loss(x_f, x_dq, is_perchannel=True)
+            loss = mse_loss(x_f, x_dq, self.granularity)
+            if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
                 best_scale[loss < loss_min] = scale[loss < loss_min]
                 best_zero_point[loss < loss_min] = zero_point[loss < loss_min]
                 loss_min[loss < loss_min] = loss[loss < loss_min]
-            else:
-                loss = mse_loss(x_f, x_dq, is_perchannel=False)
+            elif self.granularity == Granularity.LAYERWISE:
                 if loss < loss_min:
                     loss_min = loss
                     best_scale = scale
                     best_zero_point = zero_point
+            else:
+                raise NotImplementedError
         assert len(self.data_cache) == 0, "free data cache after calc_qparams"
         return best_scale, best_zero_point
diff --git a/sparsebit/quantization/observers/percentile.py b/sparsebit/quantization/observers/percentile.py
index 310a402..5c0ecf1 100644
--- a/sparsebit/quantization/observers/percentile.py
+++ b/sparsebit/quantization/observers/percentile.py
@@ -14,13 +14,7 @@ def __init__(self, config, qdesc):
         self.alpha = config.OBSERVER.PERCENTILE.ALPHA
 
     def calc_minmax(self):
-
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
-        else:
-            data = self.data_cache.get_data_for_calibration(
-                Granularity.LAYERWISE
-            ).reshape(1, -1)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
         self.data_cache.reset()
 
         channel = data.shape[0]
diff --git a/sparsebit/quantization/observers/utils.py b/sparsebit/quantization/observers/utils.py
index b1b349c..6c397c9 100644
--- a/sparsebit/quantization/observers/utils.py
+++ b/sparsebit/quantization/observers/utils.py
@@ -1,5 +1,10 @@
-def mse_loss(pred, tgt, is_perchannel=False):
-    if is_perchannel:
+from sparsebit.quantization.common import Granularity
+
+
+def mse_loss(pred, tgt, granularity: Granularity):
+    if granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
         return ((pred - tgt) ** 2).mean(-1)
-    else:
+    elif granularity == Granularity.LAYERWISE:
         return ((pred - tgt) ** 2).mean()
+    else:
+        raise NotImplementedError
diff --git a/sparsebit/quantization/quant_config.py b/sparsebit/quantization/quant_config.py
index b46532d..aafc5ac 100644
--- a/sparsebit/quantization/quant_config.py
+++ b/sparsebit/quantization/quant_config.py
@@ -18,6 +18,7 @@
 _C.W.QUANTIZER.TYPE = "uniform"
 _C.W.QUANTIZER.DISABLE = False
 _C.W.QUANTIZER.BIT = -1
+_C.W.QUANTIZER.GROUP_SIZE = -1
 _C.W.OBSERVER = CN()
 _C.W.OBSERVER.TYPE = "MINMAX"  # "MINMAX"/"MSE"/"PERCENTILE"/"KL_HISTOGRAM"
 _C.W.OBSERVER.PERCENTILE = CN()
@@ -32,6 +33,7 @@
 _C.A.QUANTIZER.TYPE = "uniform"
 _C.A.QUANTIZER.DISABLE = False
 _C.A.QUANTIZER.BIT = -1
+_C.A.QUANTIZER.GROUP_SIZE = -1
 _C.A.QUANTIZER.PACT = CN()
 _C.A.QUANTIZER.PACT.ALPHA_VALUE = 10
 _C.A.OBSERVER = CN()
diff --git a/sparsebit/quantization/quantizers/base.py b/sparsebit/quantization/quantizers/base.py
index ec14d77..0fe1f9f 100644
--- a/sparsebit/quantization/quantizers/base.py
+++ b/sparsebit/quantization/quantizers/base.py
@@ -115,8 +115,8 @@ def ch_axis(self):
         return self.observer.ch_axis
 
     @property
-    def is_perchannel(self):
-        return self.qdesc.is_perchannel
+    def granularity(self):
+        return self.qdesc.granularity
 
     @property
     def is_symmetric(self):
diff --git a/sparsebit/quantization/quantizers/lsq.py b/sparsebit/quantization/quantizers/lsq.py
index a88a7ea..b98650f 100644
--- a/sparsebit/quantization/quantizers/lsq.py
+++ b/sparsebit/quantization/quantizers/lsq.py
@@ -41,10 +41,12 @@ def calc_qparams(self):
                 "Found data less than 0, reset quantizer scheme as symmetric"
             )
             self.qdesc.set_symmetric(True)
-        if self.is_perchannel:
+        if self.granularity == Granularity.CHANNELWISE:
             scale = 2 * x_oc.abs().mean(axis=1) / math.sqrt(self.qdesc.qmax)
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             scale = 2 * x_oc.abs().mean() / math.sqrt(self.qdesc.qmax)
+        else:
+            raise NotImplementedError
         self.scale = nn.Parameter(self._broadcast_qparams(scale.to(self.device)))
         self.zero_point = self._broadcast_qparams(torch.zeros_like(self.scale))
         self.init_params = True
@@ -66,11 +68,13 @@ def _qparams_preprocess(self, x):
         return scale, zero_point
 
     def _forward(self, x, scale, zero_point):
-        if self.is_perchannel:
+        if self.granularity == Granularity.CHANNELWISE:
             num_perchannel = x.numel() / x.shape[self.qdesc.ch_axis]
             gs_ratio = 1.0 / math.sqrt(num_perchannel * self.qdesc.qmax)
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             gs_ratio = 1.0 / math.sqrt(x.numel() * self.qdesc.qmax)
+        else:
+            raise NotImplementedError
         scale = gs_scaling.apply(scale, gs_ratio)
         x_dq = STE.apply(x, scale, zero_point, self.qdesc, self.backend)
         return x_dq
diff --git a/sparsebit/quantization/quantizers/lsq_plus.py b/sparsebit/quantization/quantizers/lsq_plus.py
index 4ff4bb2..da1b428 100644
--- a/sparsebit/quantization/quantizers/lsq_plus.py
+++ b/sparsebit/quantization/quantizers/lsq_plus.py
@@ -22,7 +22,7 @@ def calc_qparams(self):
         if self.fake_fused:
             return self.scale, self.zero_point
         if not self.init_params:
-            if self.is_perchannel:
+            if self.granularity == Granularity.CHANNELWISE:
                 x_oc = self.observer.data_cache.get_data_for_calibration(
                     Granularity.CHANNELWISE
                 )
@@ -39,7 +39,7 @@ def calc_qparams(self):
                     self._broadcast_qparams(scale.to(self.device))
                 )
                 self.zero_point = self._broadcast_qparams(torch.zeros_like(self.scale))
-            else:
+            elif self.granularity == Granularity.LAYERWISE:
                 assert (
                     not self.is_symmetric
                 ), "LSQ+ only support per-tensor-affine quant for activation"
@@ -51,6 +51,8 @@ def calc_qparams(self):
                 self.zero_point = nn.Parameter(
                     self._broadcast_qparams(zero_point.to(self.device))
                 )
+            else:
+                raise NotImplementedError
             self.init_params = True
         return self.scale, self.zero_point
 
@@ -70,11 +72,13 @@ def _qparams_preprocess(self, x):
         return scale, zero_point
 
     def _forward(self, x, scale, zero_point):
-        if self.is_perchannel:
+        if self.granularity == Granularity.CHANNELWISE:
             num_perchannel = x.numel() / x.shape[self.qdesc.ch_axis]
             gs_ratio = 1.0 / math.sqrt(num_perchannel * self.qdesc.qmax)
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             gs_ratio = 1.0 / math.sqrt(x.numel() * self.qdesc.qmax)
+        else:
+            raise NotImplementedError
         scale = gs_scaling.apply(scale, gs_ratio)
         if zero_point.requires_grad:
             zero_point = gs_scaling.apply(zero_point, gs_ratio)
diff --git a/sparsebit/quantization/quantizers/pact.py b/sparsebit/quantization/quantizers/pact.py
index 69082df..8825d95 100644
--- a/sparsebit/quantization/quantizers/pact.py
+++ b/sparsebit/quantization/quantizers/pact.py
@@ -5,7 +5,7 @@
 from sparsebit.quantization.quantizers import Quantizer as BaseQuantizer
 from sparsebit.quantization.quantizers import register_quantizer
 
-from sparsebit.quantization.common import QuantTarget
+from sparsebit.quantization.common import QuantTarget, Granularity
 
 from .quant_tensor import STE
 
@@ -18,7 +18,9 @@ def __init__(self, config):
         assert (
             self.qdesc.target == QuantTarget.FEATURE
         ), "PACT only support feature quantization"
-        assert not self.qdesc.is_perchannel, "PACT no yet supports per-channel"
+        assert (
+            self.granularity == Granularity.LAYERWISE
+        ), "PACT only supports per-tensor now!"
         self.init_alpha_value = config.QUANTIZER.PACT.ALPHA_VALUE
 
     def calc_qparams(self):
diff --git a/sparsebit/quantization/quantizers/quant_descriptor.py b/sparsebit/quantization/quantizers/quant_descriptor.py
index e516520..f2fd5de 100644
--- a/sparsebit/quantization/quantizers/quant_descriptor.py
+++ b/sparsebit/quantization/quantizers/quant_descriptor.py
@@ -1,5 +1,5 @@
 import torch
-from sparsebit.quantization.common import get_qscheme
+from sparsebit.quantization.common import get_qscheme, Granularity
 
 
 class QuantDescriptor:
@@ -8,26 +8,42 @@ def __init__(self, cfg):
         self._target = cfg.TARGET[0]
         self._scheme = get_qscheme(cfg.QSCHEME)
         self._bit = cfg.QUANTIZER.BIT
+        self._group_size = cfg.QUANTIZER.GROUP_SIZE
+        if self._group_size != -1:
+            assert cfg.QSCHEME in ["per-group-symmetric", "per-group-affine"]
         self._qmin, self._qmax, self._type = self.calc_qmin_qmax(
             self._bit, self._scheme
         )
         self._ch_axis = self._set_channel_axis()
         self._bs_axis = self._set_batchsize_axis()
-        self.is_perchannel = (
-            self._scheme == torch.per_channel_symmetric
-            or self._scheme == torch.per_channel_affine
-        )
+        self.granularity = {
+            torch.per_channel_symmetric: Granularity.CHANNELWISE,
+            torch.per_channel_affine: Granularity.CHANNELWISE,
+            torch.per_tensor_symmetric: Granularity.LAYERWISE,
+            torch.per_tensor_affine: Granularity.LAYERWISE,
+            "per-group-symmetric": Granularity.GROUPWISE,
+            "per-group-affine": Granularity.GROUPWISE,
+        }[self._scheme]
         self.is_symmetric = (
             self._scheme == torch.per_channel_symmetric
             or self._scheme == torch.per_tensor_symmetric
+            or self._scheme == "per-group-symmetric"
         )
 
     def calc_qmin_qmax(self, bit, scheme):
-        if scheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]:
+        if scheme in [
+            torch.per_channel_symmetric,
+            torch.per_tensor_symmetric,
+            "per-group-symmetric",
+        ]:
             qmin = -(2 ** (bit - 1))
             qmax = 2 ** (bit - 1) - 1
             _type = "int{}".format(bit)
-        elif scheme in [torch.per_channel_affine, torch.per_tensor_affine]:
+        elif scheme in [
+            torch.per_channel_affine,
+            torch.per_tensor_affine,
+            "per-group-affine",
+        ]:
             qmin = 0
             qmax = 2**bit - 1
             _type = "uint{}".format(bit)
@@ -61,14 +77,19 @@ def set_bit(self, bit):
         self._bit = bit
         self._qmin, self._qmax, self._type = self.calc_qmin_qmax(bit, self._scheme)
 
+    def set_group_size(self, group_size):
+        self._group_size = group_size
+
     def set_symmetric(self, is_symmetric: bool):
         self.is_symmetric = is_symmetric
         self._scheme = {
-            (True, True): torch.per_channel_symmetric,
-            (True, False): torch.per_channel_affine,
-            (False, True): torch.per_tensor_symmetric,
-            (False, False): torch.per_tensor_affine,
-        }[(self.is_perchannel, self.is_symmetric)]
+            (Granularity.CHANNELWISE, True): torch.per_channel_symmetric,
+            (Granularity.CHANNELWISE, False): torch.per_channel_affine,
+            (Granularity.LAYERWISE, True): torch.per_tensor_symmetric,
+            (Granularity.LAYERWISE, False): torch.per_tensor_affine,
+            (Granularity.GROUPWISE, True): "per-group-symmetric",
+            (Granularity.GROUPWISE, False): "per-group-affine",
+        }[(self.granularity, self.is_symmetric)]
         self._qmin, self._qmax, self._type = self.calc_qmin_qmax(
             self._bit, self._scheme
         )
@@ -105,7 +126,11 @@ def ch_axis(self):
     def bs_axis(self):
         return self._bs_axis
 
+    @property
+    def group_size(self):
+        return self._group_size
+
     def __repr__(self):
-        return self._type + "\t qmin: {} qmax: {}, qscheme: {}".format(
-            self.qmin, self.qmax, self.scheme
+        return self._type + "\t qmin: {} qmax: {}, qscheme: {}, group_size: {}".format(
+            self.qmin, self.qmax, self.scheme, self.group_size
         )
diff --git a/sparsebit/quantization/quantizers/quant_tensor.py b/sparsebit/quantization/quantizers/quant_tensor.py
index bad3a1c..8439243 100644
--- a/sparsebit/quantization/quantizers/quant_tensor.py
+++ b/sparsebit/quantization/quantizers/quant_tensor.py
@@ -2,7 +2,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from sparsebit.quantization.common import Backend
+from sparsebit.quantization.common import Backend, Granularity, QuantTarget
 
 if torch.cuda.is_available():
     from torch.utils.cpp_extension import load
@@ -87,7 +87,7 @@ def backward(ctx, gout):
        if torch.cuda.is_available():
            if x.dtype == torch.float16:  # A workaround
                x = x.float()
-           if qdesc.is_perchannel:
+           if qdesc.granularity == Granularity.CHANNELWISE:
                gx, gs, gzp = fake_quant_kernel.quant_perchannel_backward(
                    x.contiguous(),
                    scale.contiguous(),
@@ -98,7 +98,7 @@ def backward(ctx, gout):
                    qdesc.ch_axis,
                    0,
                )
-           else:
+           elif qdesc.granularity == Granularity.LAYERWISE:
                gx, gs, gzp = fake_quant_kernel.quant_pertensor_backward(
                    x.contiguous(),
                    scale,
@@ -108,6 +108,8 @@ def backward(ctx, gout):
                    qmax,
                    0,
                )
+           else:
+               raise NotImplementedError
            gs = gs if scale.requires_grad else None
            gzp = gzp if zero_point.requires_grad else None
        else:
@@ -136,7 +138,7 @@ def trt_fake_quant(x_f, scale, zero_point, qdesc):
     if torch.cuda.is_available() and "cuda" in x_f.device.type:
         if x_f.dtype == torch.float16:  # A workaround
             x_f = x_f.float()
-        if qdesc.is_perchannel:
+        if qdesc.granularity == Granularity.CHANNELWISE:
             x_dq = fake_quant_kernel.quant_perchannel_forward(
                 x_f.contiguous(),
                 scale.contiguous(),
@@ -146,10 +148,12 @@ def trt_fake_quant(x_f, scale, zero_point, qdesc):
                 qdesc.ch_axis,
                 0,
             )
-        else:
+        elif qdesc.granularity == Granularity.LAYERWISE:
             x_dq = fake_quant_kernel.quant_pertensor_forward(
                 x_f.contiguous(), scale, zero_point, qmin, qmax, 0
             )
+        else:
+            raise NotImplementedError
     else:
         x_q = torch.clamp((x_f / scale).round(), qmin, qmax)
         x_dq = x_q * scale
@@ -164,7 +168,7 @@ def ort_fake_quant(x_f, scale, zero_point, qdesc):
     if torch.cuda.is_available() and "cuda" in x_f.device.type:
         if x_f.dtype == torch.float16:  # A workaround
             x_f = x_f.float()
-        if qdesc.is_perchannel:
+        if qdesc.granularity == Granularity.CHANNELWISE:
             x_dq = fake_quant_kernel.quant_perchannel_forward(
                 x_f.contiguous(),
                 scale.contiguous(),
@@ -174,14 +178,48 @@ def ort_fake_quant(x_f, scale, zero_point, qdesc):
                 qdesc.ch_axis,
                 0,
             )
-        else:
+        elif qdesc.granularity == Granularity.LAYERWISE:
             x_dq = fake_quant_kernel.quant_pertensor_forward(
                 x_f.contiguous(), scale, zero_point, qmin, qmax, 0
             )
+        elif qdesc.granularity == Granularity.GROUPWISE:
+            origin_shape = x_f.shape
+            grouped_shape = torch.Size([scale.numel(), -1])
+            scale = scale.reshape(grouped_shape)
+            zero_point = zero_point.reshape(grouped_shape)
+            x_f = x_f.reshape(grouped_shape)
+
+            x_dq = fake_quant_kernel.quant_perchannel_forward(
+                x_f.contiguous(),
+                scale.contiguous(),
+                zero_point.contiguous(),
+                qmin,
+                qmax,
+                0,
+                0,
+            )
+            x_dq = x_dq.reshape(origin_shape)
+        else:
+            raise NotImplementedError
     else:
-        zp = zero_point.round()
-        x_q = torch.clamp((x_f / scale).round() + zp, qmin, qmax)
-        x_dq = (x_q - zp) * scale
+        if qdesc.granularity == Granularity.GROUPWISE:
+            zp = zero_point.round()
+            origin_shape = x_f.shape
+            if qdesc.target == QuantTarget.FEATURE:
+                grouped_shape = torch.Size([x_f.shape[0], scale.numel(), -1])
+            else:
+                grouped_shape = torch.Size([scale.numel(), -1])
+            scale = scale.reshape(grouped_shape)
+            zp = zp.reshape(grouped_shape)
+            x_f = x_f.reshape(grouped_shape)
+            x_q = torch.clamp((x_f / scale).round() + zp, qmin, qmax)
+            x_dq = (x_q - zp) * scale
+            x_dq = x_dq.reshape(origin_shape)
+
+        else:
+            zp = zero_point.round()
+            x_q = torch.clamp((x_f / scale).round() + zp, qmin, qmax)
+            x_dq = (x_q - zp) * scale
     return x_dq
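
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the group-wise bucketing that the new Granularity.GROUPWISE branch of get_data_for_calibration performs for weights. It assumes a 2-D (oc, ic) weight whose ic dimension satisfies the same divisibility check as the patch; the shapes and variable names are illustrative only.

import torch

# Illustrative shapes only: 8 output channels, 64 input channels, groups of 16.
oc, ic, group_size = 8, 64, 16
w = torch.randn(oc, ic)

# Same admissibility check as the patch: ic fits in one group or is divisible.
assert ic <= group_size or ic % group_size == 0, "ic num must be divisible by group size"
group_num = max(ic // group_size, 1)

# (oc, ic) -> (oc * group_num, group_size): one row per quantization group,
# mirroring d.reshape(d.shape[0] * group_num, -1) in the weight branch above.
w_grouped = w.reshape(oc * group_num, -1)

# Per-group statistics, as the min-max style observers consume them.
min_val = w_grouped.min(dim=1).values  # shape: (oc * group_num,)
max_val = w_grouped.max(dim=1).values
print(min_val.shape, max_val.shape)    # torch.Size([32]) torch.Size([32])

Usage-wise, this path is reached through the config keys added above: set W.QSCHEME (or A.QSCHEME) to "per-group-symmetric" or "per-group-affine" and W.QUANTIZER.GROUP_SIZE to the desired group size; QuantDescriptor asserts that a GROUP_SIZE other than -1 is only used together with a per-group QSCHEME.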