Jst/add groupwise quantization #150

Open · wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions sparsebit/quantization/common.py
@@ -5,6 +5,7 @@
 class Granularity(Enum):
     LAYERWISE = 0
     CHANNELWISE = 1
+    GROUPWISE = 2


 class QuantTarget(Enum):
@@ -44,6 +45,10 @@ def get_qscheme(qscheme):
         return torch.per_channel_symmetric
     if qscheme == "per-channel-affine":
         return torch.per_channel_affine
+    if qscheme == "per-group-symmetric":
+        return "per-group-symmetric"
+    if qscheme == "per-group-affine":
+        return "per-group-affine"
     raise TypeError(
         "only support a qscheme equals to per-[tensor/channel]-[affine/symmetric] , not {}".format(
             qscheme
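Note: since torch has no per-group qscheme enum, the two new schemes round-trip as plain strings, and downstream checks (see the observers below) compare against those strings alongside the torch enums. A minimal sketch of the resulting behavior:

import torch
from sparsebit.quantization.common import Granularity, get_qscheme

# torch-native schemes still map to torch qscheme objects...
assert get_qscheme("per-channel-symmetric") is torch.per_channel_symmetric
# ...while the new per-group schemes are returned as the strings themselves
assert get_qscheme("per-group-affine") == "per-group-affine"
assert Granularity.GROUPWISE.value == 2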
27 changes: 15 additions & 12 deletions sparsebit/quantization/observers/aciq.py
@@ -63,16 +63,18 @@ def __init__(self, config, qdesc):
         self.gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5)

     def calc_laplace_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             b = torch.mean(torch.abs(data - data.mean(1).unsqueeze(1)), dim=1)
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        elif self.granularity == Granularity.LAYERWISE:
             b = torch.mean(torch.abs(data - data.mean()))
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         is_half_range = data.min() >= 0
         if (
-            self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
+            self.qdesc.scheme
+            in [torch.per_channel_affine, torch.per_tensor_affine, "per-group-affine"]
             and is_half_range
         ):
             max_val = self.alpha_laplace_positive[self.qdesc.bit] * b
@@ -85,25 +87,26 @@ def calc_gaus_minmax(self):
     def calc_gaus_minmax(self):
         if self.qdesc.target == QuantTarget.FEATURE:
             batch_size = self.data_cache.get_batch_size()
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data.max(axis=1).values
             min_val = data.min(axis=1).values
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        elif self.granularity == Granularity.LAYERWISE:
             max_val = data.max()
             min_val = data.min()
-            self.data_cache.get_batch_size
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         is_half_range = data.min() >= 0
-        num_elements = data.numel()
+        num_elements = data[0].numel()
         if self.qdesc.target == QuantTarget.FEATURE:
             num_elements /= batch_size
         std = ((max_val - min_val) * self.gaus_const) / (
             (2 * math.log(num_elements)) ** 0.5
         )
         if (
-            self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
+            self.qdesc.scheme
+            in [torch.per_channel_affine, torch.per_tensor_affine, "per-group-affine"]
             and is_half_range
         ):
             max_val = self.alpha_gaus_positive[self.qdesc.bit] * std
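A note on the num_elements change: get_data_for_calibration now always returns a 2-D tensor with one row per channel/group (a single row for layerwise), so data[0].numel() is the per-row element count N used by the Gaussian clipping estimate std = (max_val - min_val) * gaus_const / sqrt(2 * ln(N)). A small sketch with hypothetical shapes:

import torch

data = torch.randn(4, 1024)        # hypothetical: 4 groups x 1024 values each
per_group_elems = data[0].numel()  # 1024 -- the N that enters 2 * ln(N)
total_elems = data.numel()         # 4096 -- the old count, wrong once rows are groups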
75 changes: 62 additions & 13 deletions sparsebit/quantization/observers/base.py
@@ -10,7 +10,10 @@ def __init__(self, qdesc):
         self._data_cache = []

     def update(self, data):
-        self._data_cache.append(data)
+        if self.ch_axis != 0:
+            self._data_cache.append(data.transpose(self.ch_axis, 0))
+        else:
+            self._data_cache.append(data)

     def reset(self):
         self._data_cache = []
@@ -23,16 +26,42 @@ def get_data_for_calibration(self, granularity: Granularity):
         assert granularity in [
             Granularity.LAYERWISE,
             Granularity.CHANNELWISE,
-        ], "only layerwise or channelwise quantization are supported now!"
-        if granularity == Granularity.CHANNELWISE:
-            data = torch.cat(self._data_cache, dim=self.qdesc.ch_axis)
-            if self.qdesc.ch_axis != 0:
-                data = data.transpose(0, self.qdesc.ch_axis)
-            data = data.flatten(1)
-        elif granularity == Granularity.LAYERWISE:
-            data = torch.cat([d.reshape(-1) for d in self._data_cache], axis=0)
-        else:
-            raise NotImplementedError
+            Granularity.GROUPWISE,
+        ], "only layerwise, channelwise and groupwise quantization are supported now!"
+        if granularity == Granularity.LAYERWISE:
+            data = torch.cat([d.reshape(1, -1) for d in self._data_cache], axis=1)
+        elif granularity == Granularity.CHANNELWISE:
+            data = torch.cat(
+                [d.reshape(d.shape[0], -1) for d in self._data_cache], axis=1
+            )
+        elif granularity == Granularity.GROUPWISE:
+            if self.target == QuantTarget.FEATURE:  # feature: group on channel dim
+                assert (
+                    self._data_cache[0].shape[0] <= self.group_size
+                    or self._data_cache[0].shape[0] % self.group_size == 0
+                ), "channel num must be divisible by group size! got group size {} and channel num {}".format(
+                    self.group_size, self._data_cache[0].shape[0]
+                )
+                group_num = max(self._data_cache[0].shape[0] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size = self._data_cache[0].shape[0]
+                data = torch.cat(
+                    [d.reshape(group_num, -1) for d in self._data_cache], axis=1
+                )
+            else:  # weight: group on ic dim
+                assert (
+                    self._data_cache[0].shape[1] <= self.group_size
+                    or self._data_cache[0].shape[1] % self.group_size == 0
+                ), "ic num must be divisible by group size! got group size {} and ic num {}".format(
+                    self.group_size, self._data_cache[0].shape[1]
+                )
+                group_num = max(self._data_cache[0].shape[1] // self.group_size, 1)
+                if group_num == 1:
+                    self.qdesc.set_group_size = self._data_cache[0].shape[1]
+                data = torch.cat(
+                    [d.reshape(d.shape[0] * group_num, -1) for d in self._data_cache],
+                    axis=1,
+                )
         return data

     def get_batch_size(self):
@@ -44,6 +73,18 @@ def get_data_cache(self):
         assert len(self._data_cache), "No data cached!"
         return self._data_cache

+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
+
+    @property
+    def ch_axis(self):
+        return self.qdesc.ch_axis
+

 class Observer(nn.Module):
     def __init__(self, config, qdesc):
@@ -79,9 +120,17 @@ def calc_qparams_with_minmax(self, min_val, max_val):
         return scale, zero_point

     @property
-    def is_perchannel(self):
-        return self.qdesc.is_perchannel
+    def granularity(self):
+        return self.qdesc.granularity

     @property
     def is_symmetric(self):
         return self.qdesc.is_symmetric
+
+    @property
+    def target(self):
+        return self.qdesc.target
+
+    @property
+    def group_size(self):
+        return self.qdesc.group_size
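The heart of the groupwise path is the reshape above: a weight of shape [oc, ic] is split along the input-channel dim into ic // group_size groups, one row per group, so the observers' per-row statistics carry over unchanged; when the grouped axis is shorter than group_size, the whole axis becomes a single group. A minimal sketch with made-up shapes:

import torch

group_size = 64
weight = torch.randn(128, 256)                     # hypothetical [oc, ic] weight
group_num = max(weight.shape[1] // group_size, 1)  # 256 // 64 = 4 groups per oc
data = weight.reshape(weight.shape[0] * group_num, -1)
print(data.shape)  # torch.Size([512, 64]) -- one row of 64 values per group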
11 changes: 5 additions & 6 deletions sparsebit/quantization/observers/kl_histogram.py
@@ -102,10 +102,8 @@ def __init__(self, config, qdesc):
         self.bins = 2048

     def calc_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(
-                Granularity.CHANNELWISE
-            ).cpu()
+        data = self.data_cache.get_data_for_calibration(self.granularity).cpu()
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             channel = data.shape[0]
             abs_max = data.abs().max(axis=1).values
             _min = torch.empty(channel)
@@ -131,8 +129,7 @@
                 _max[c] = th[c]
             self.max_val = _max.to(self.device)
             self.min_val = _min.to(self.device)
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE).cpu()
+        elif self.granularity == Granularity.LAYERWISE:
             abs_max = data.abs().max()
             th = get_best_threshold(
                 data=data,
@@ -147,5 +144,7 @@
                 if data.min() < 0
                 else torch.zeros(1).to(self.device)
             )
+        else:
+            raise NotImplementedError
         self.data_cache.reset()
         return self.min_val, self.max_val
10 changes: 5 additions & 5 deletions sparsebit/quantization/observers/minmax.py
@@ -1,7 +1,7 @@
 import torch
 from sparsebit.quantization.observers import Observer as BaseObserver
 from sparsebit.quantization.observers import register_observer
-from sparsebit.quantization.common import Granularity
+from sparsebit.quantization.common import Granularity, QuantTarget


 @register_observer
@@ -12,13 +12,13 @@ def __init__(self, config, qdesc):
         super(Observer, self).__init__(config, qdesc)

     def calc_minmax(self):
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data.max(axis=1).values
             min_val = data.min(axis=1).values
-        else:
-            data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
+        else:  # Granularity.LAYERWISE
             min_val, max_val = data.min(), data.max()
+
         self.data_cache.reset()
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
5 changes: 4 additions & 1 deletion sparsebit/quantization/observers/moving_average.py
@@ -1,7 +1,7 @@
 import torch
 from sparsebit.quantization.observers import Observer as BaseObserver
 from sparsebit.quantization.observers import register_observer
-from sparsebit.quantization.common import QuantTarget
+from sparsebit.quantization.common import Granularity, QuantTarget


 @register_observer
@@ -14,6 +14,9 @@ def __init__(self, config, qdesc):
             hasattr(config.OBSERVER, "MOVING_AVERAGE")
             and self.qdesc.target == QuantTarget.FEATURE
         ), "Moving_average observer only support feature observing!"
+        assert (
+            self.granularity == Granularity.LAYERWISE
+        ), "Moving_average observer only support layerwise quantization!"
         self.ema_ratio = config.OBSERVER.MOVING_AVERAGE.EMA_RATIO

     def calc_minmax(self):
23 changes: 14 additions & 9 deletions sparsebit/quantization/observers/mse.py
@@ -16,21 +16,23 @@ def __init__(self, config, qdesc):
         self.alpha = config.OBSERVER.PERCENTILE.ALPHA

     def calc_minmax(self, data_c_first):
-        if self.is_perchannel:
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             max_val = data_c_first.max(axis=1).values
             min_val = data_c_first.min(axis=1).values
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             min_val, max_val = data_c_first.min(), data_c_first.max()
+        else:
+            raise NotImplementedError
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
         return self.min_val, self.max_val

     def calc_qparams(self):
-        data_c_first = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
+        data_c_first = self.data_cache.get_data_for_calibration(self.granularity)
         self.data_cache.reset()
         min_val, max_val = self.calc_minmax(data_c_first)
         x_f = data_c_first.to(self.device)
-        if self.is_perchannel:
+        if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
             best_scale = torch.tensor(
                 [1.0 for _ in range(data_c_first.shape[0])], device=self.device
             )
@@ -40,24 +42,27 @@
             loss_min = torch.tensor(
                 [1e10 for _ in range(data_c_first.shape[0])], device=self.device
             )
-        else:
+        elif self.granularity == Granularity.LAYERWISE:
             best_scale, best_zero_point = None, None
             loss_min = 1e10
+        else:
+            raise NotImplementedError
         for i in range(80):
             cur_min_val = min_val * (1.0 - (i * 0.01))
             cur_max_val = max_val * (1.0 - (i * 0.01))
             scale, zero_point = self.calc_qparams_with_minmax(cur_min_val, cur_max_val)
             x_dq = STE.apply(x_f, scale, zero_point, self.qdesc, Backend.VIRTUAL)
-            if self.is_perchannel:
-                loss = mse_loss(x_f, x_dq, is_perchannel=True)
+            loss = mse_loss(x_f, x_dq, self.granularity)
+            if self.granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
                 best_scale[loss < loss_min] = scale[loss < loss_min]
                 best_zero_point[loss < loss_min] = zero_point[loss < loss_min]
                 loss_min[loss < loss_min] = loss[loss < loss_min]
-            else:
-                loss = mse_loss(x_f, x_dq, is_perchannel=False)
+            elif self.granularity == Granularity.LAYERWISE:
                 if loss < loss_min:
                     loss_min = loss
                     best_scale = scale
                     best_zero_point = zero_point
+            else:
+                raise NotImplementedError
         assert len(self.data_cache) == 0, "free data cache after calc_qparams"
         return best_scale, best_zero_point
8 changes: 1 addition & 7 deletions sparsebit/quantization/observers/percentile.py
@@ -14,13 +14,7 @@ def __init__(self, config, qdesc):
         self.alpha = config.OBSERVER.PERCENTILE.ALPHA

     def calc_minmax(self):
-
-        if self.is_perchannel:
-            data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
-        else:
-            data = self.data_cache.get_data_for_calibration(
-                Granularity.LAYERWISE
-            ).reshape(1, -1)
+        data = self.data_cache.get_data_for_calibration(self.granularity)
         self.data_cache.reset()
         channel = data.shape[0]
11 changes: 8 additions & 3 deletions sparsebit/quantization/observers/utils.py
@@ -1,5 +1,10 @@
-def mse_loss(pred, tgt, is_perchannel=False):
-    if is_perchannel:
+from sparsebit.quantization.common import Granularity
+
+
+def mse_loss(pred, tgt, granularity: Granularity):
+    if granularity in [Granularity.CHANNELWISE, Granularity.GROUPWISE]:
         return ((pred - tgt) ** 2).mean(-1)
-    else:
+    elif granularity == Granularity.LAYERWISE:
         return ((pred - tgt) ** 2).mean()
+    else:
+        raise NotImplementedError
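Callers now pass a Granularity instead of a boolean; groupwise reuses the per-row reduction because groups already occupy rows after get_data_for_calibration. A usage sketch:

import torch
from sparsebit.quantization.common import Granularity
from sparsebit.quantization.observers.utils import mse_loss

pred, tgt = torch.randn(8, 64), torch.randn(8, 64)
per_row = mse_loss(pred, tgt, Granularity.GROUPWISE)  # shape [8]: one loss per row/group
overall = mse_loss(pred, tgt, Granularity.LAYERWISE)  # 0-dim scalar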
2 changes: 2 additions & 0 deletions sparsebit/quantization/quant_config.py
@@ -18,6 +18,7 @@
 _C.W.QUANTIZER.TYPE = "uniform"
 _C.W.QUANTIZER.DISABLE = False
 _C.W.QUANTIZER.BIT = -1
+_C.W.QUANTIZER.GROUP_SIZE = -1
 _C.W.OBSERVER = CN()
 _C.W.OBSERVER.TYPE = "MINMAX"  # "MINMAX"/"MSE"/"PERCENTILE"/"KL_HISTOGRAM"
 _C.W.OBSERVER.PERCENTILE = CN()
@@ -32,6 +33,7 @@
 _C.A.QUANTIZER.TYPE = "uniform"
 _C.A.QUANTIZER.DISABLE = False
 _C.A.QUANTIZER.BIT = -1
+_C.A.QUANTIZER.GROUP_SIZE = -1
 _C.A.QUANTIZER.PACT = CN()
 _C.A.QUANTIZER.PACT.ALPHA_VALUE = 10
 _C.A.OBSERVER = CN()
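With the new keys, a 4-bit groupwise weight configuration might be assembled as sketched below (assumptions for illustration: _C is the default config defined in this file, and the default GROUP_SIZE = -1 presumably leaves grouping disabled):

from sparsebit.quantization.quant_config import _C

qconfig = _C.clone()
qconfig.W.QUANTIZER.BIT = 4
qconfig.W.QUANTIZER.GROUP_SIZE = 128  # hypothetical: 128 input channels share one scale/zero-point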
4 changes: 2 additions & 2 deletions sparsebit/quantization/quantizers/base.py
@@ -115,8 +115,8 @@ def ch_axis(self):
         return self.observer.ch_axis

     @property
-    def is_perchannel(self):
-        return self.qdesc.is_perchannel
+    def granularity(self):
+        return self.qdesc.granularity

     @property
     def is_symmetric(self):