Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] adavanced ptq #112

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions mqbench/advanced_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
a_para = []
for name, layer in subgraph.named_modules():
if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
weight_quantizer = layer.weight_fake_quant
if hasattr(layer, 'weight_fake_quant'):
weight_quantizer = layer.weight_fake_quant
else:
continue
# assert isinstance(weight_quantizer, adaround_quantizer) is True
weight_quantizer.init(layer.weight.data, config.round_mode)
w_para += [weight_quantizer.alpha]
Expand All @@ -304,6 +307,8 @@ def subgraph_reconstruction(subgraph, cached_inps, cached_oups, config):
a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=config.max_count, eta_min=0.)
else:
a_opt, a_scheduler = None, None
if w_para == []:
return
w_opt = torch.optim.Adam(w_para)

loss_func = LossFunction(subgraph=subgraph, weight=config.weight, max_count=config.max_count, b_range=config.b_range,
Expand Down Expand Up @@ -633,7 +638,15 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
continue
logger.info('the node list is below!')
logger.info(layer_node_list)
fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
if layer_node_list[-1] in qnode2fpnode_dict:
fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
use_next_input = False
else:
out_node = layer_node_list[-1]
inp_node = list(out_node.users)[0]
qinp_node = qnode2fpnode_dict[inp_node]
fp32_module = fp32_modules[qinp_node]
use_next_input = True
fp32_all_inps = []
quant_all_inps = []
fp32_final_oups = None
Expand All @@ -642,15 +655,35 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
continue
else:
fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
quant_module = quant_modules[_node]
if _node in quant_modules:
fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
quant_module = quant_modules[_node]
use_next_input_ = False
else:
_node = list(_node.users)[0]
fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
quant_module = quant_modules[_node]
use_next_input_ = True
# fp32 inps: [out_b1, out_b2, ...]
_, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
_, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
_, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
if use_next_input is False:
_, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
else:
fp32_oups, _ = save_inp_oup_data(fp32_model, fp32_module, None, cali_data,
store_inp=(not out_is_cached), store_oup=False, keep_gpu=config.keep_gpu)
fp32_oups = sum(fp32_oups, [])
if use_next_input_ is False:
_, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
_, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
else:
fp32_inps, _ = save_inp_oup_data(fp32_model, fp32_inp_module,None, cali_data,
store_inp=(config.prob < 1.0), store_oup=False, keep_gpu=config.keep_gpu)
quant_inps, _ = save_inp_oup_data(quant_model, quant_module, None, cali_data,
store_inp=True, store_oup=False, keep_gpu=config.keep_gpu)
fp32_inps = sum(fp32_inps, [])
quant_inps = sum(quant_inps, [])
fp32_all_inps.append(fp32_inps)
quant_all_inps.append(quant_inps)
if not out_is_cached:
Expand All @@ -674,3 +707,4 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
enable_quantization(quant_modules[node])
logger.info(f'set the node {node.target} in quant')
return quant_model