From 0900dcd9d4f61ab23a443827e16abb63ddffd117 Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Wed, 6 Oct 2021 11:18:23 +0200
Subject: [PATCH 01/10] Add quantization to QA scripts

---
 scripts/question_answering/models.py    | 19 +++++--
 scripts/question_answering/run_squad.py | 72 ++++++++++++++++++++++++-
 2 files changed, 85 insertions(+), 6 deletions(-)

diff --git a/scripts/question_answering/models.py b/scripts/question_answering/models.py
index b2ef10640b..742f6eb429 100644
--- a/scripts/question_answering/models.py
+++ b/scripts/question_answering/models.py
@@ -180,6 +180,8 @@ def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
         self.answerable_scores.add(nn.Dense(2, flatten=False,
                                             weight_initializer=weight_initializer,
                                             bias_initializer=bias_initializer))
+        self.quantized_backbone = None
+        self.quantized = False
 
     def get_start_logits(self, contextual_embedding, p_mask):
         """
@@ -287,10 +289,14 @@ def forward(self, tokens, token_types, valid_length, p_mask, start_position):
             Shape (batch_size, sequence_length)
         answerable_logits
         """
+        backbone_net = self.backbone
+        if self.quantized:
+            backbone_net = self.quantized_backbone
+
         if self.use_segmentation:
-            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+            contextual_embeddings = backbone_net(tokens, token_types, valid_length)
         else:
-            contextual_embeddings = self.backbone(tokens, valid_length)
+            contextual_embeddings = backbone_net(tokens, valid_length)
         start_logits = self.get_start_logits(contextual_embeddings, p_mask)
         end_logits = self.get_end_logits(contextual_embeddings,
                                          np.expand_dims(start_position, axis=1),
@@ -337,11 +343,16 @@ def inference(self, tokens, token_types, valid_length, p_mask,
             The answerable logits. Here 0 --> answerable and 1 --> not answerable.
             Shape (batch_size, sequence_length, 2)
         """
+        backbone_net = self.backbone
+        if self.quantized:
+            backbone_net = self.quantized_backbone
+
         # Shape (batch_size, sequence_length, C)
         if self.use_segmentation:
-            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+            contextual_embeddings = backbone_net(tokens, token_types, valid_length)
         else:
-            contextual_embeddings = self.backbone(tokens, valid_length)
+            contextual_embeddings = backbone_net(tokens, valid_length)
+
         start_logits = self.get_start_logits(contextual_embeddings, p_mask)
         # The shape of start_top_index will be (..., start_top_n)
         start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index 521ee15a47..4e18dd1dfe 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -147,9 +147,9 @@ def parse_args():
     parser.add_argument('--max_saved_ckpt', type=int, default=5,
                         help='The maximum number of saved checkpoints')
     parser.add_argument('--dtype', type=str, default='float32',
-                        help='Data type used for evaluation. Either float32 or float16. When you '
+                        help='Data type used for evaluation. Either float32, float16 or int8. When you '
                             'use --dtype float16, amp will be turned on in the training phase and '
-                            'fp16 will be used in evaluation.')
+                            'fp16 will be used in evaluation. For now int8 data type is supported on CPU only.')
     args = parser.parse_args()
     return args
 
@@ -815,6 +815,71 @@ def predict_extended(original_feature,
     assert len(nbest_json) >= 1
     return not_answerable_score, nbest[0][0], nbest_json
 
+def quantize_and_calibrate(net, dataloader):
+    class QuantizationDataLoader(mx.gluon.data.DataLoader):
+        def __init__(self, dataloader, use_segmentation):
+            self._dataloader = dataloader
+            self._iter = None
+            self._use_segmentation = use_segmentation
+
+        def __iter__(self):
+            self._iter = iter(self._dataloader)
+            return self
+
+        def __next__(self):
+            batch = next(self._iter)
+            if self._use_segmentation:
+                return [batch.data, batch.segment_ids, batch.valid_length]
+            else:
+                return [batch.data, batch.valid_length]
+
+        def __del__(self):
+            del(self._dataloader)
+
+    class BertLayerCollector(mx.contrib.quantization.CalibrationCollector):
+        """Saves layer output min and max values in a dict with layer names as keys.
+        The collected min and max values will be directly used as thresholds for quantization.
+        """
+        def __init__(self, clip_min, clip_max):
+            super(BertLayerCollector, self).__init__()
+            self.clip_min = clip_min
+            self.clip_max = clip_max
+
+        def collect(self, name, op_name, arr):
+            """Callback function for collecting min and max values from an NDArray."""
+            if name not in self.include_layers:
+                return
+            print(name)
+            arr = arr.copyto(mx.cpu()).asnumpy()
+            min_range = np.min(arr)
+            max_range = np.max(arr)
+
+            if (op_name.find("npi_copy") != -1 or op_name.find("LayerNorm") != -1) and max_range > self.clip_max:
+                max_range = self.clip_max
+            if op_name.find('Dropout') != -1 and min_range < self.clip_min:
+                print(name, op_name)
+                min_range = self.clip_min
+
+            if name in self.min_max_dict:
+                cur_min_max = self.min_max_dict[name]
+                self.min_max_dict[name] = (min(cur_min_max[0], min_range),
+                                           max(cur_min_max[1], max_range))
+            else:
+                self.min_max_dict[name] = (min_range, max_range)
+
+    calib_data = QuantizationDataLoader(dataloader, net.use_segmentation)
+    net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
+                                exclude_layers=None,
+                                exclude_layers_match=None,
+                                calib_data=calib_data,
+                                calib_mode='custom',
+                                LayerOutputCollector=BertLayerCollector(clip_min=-50, clip_max=10),
+                                num_calib_batches=10,
+                                ctx=mx.current_context(),
+                                logger=logging.getLogger())
+    net.quantized = True
+    return net
+
 
 def evaluate(args, last=True):
     store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
@@ -860,6 +925,9 @@ def eval_validation(ckpt_name, best_eval):
         num_workers=0,
         shuffle=False)
 
+    if args.dtype == 'int8':
+        quantize_and_calibrate(qa_net, dev_dataloader)
+
     log_interval = args.eval_log_interval
     all_results = []
     epoch_tic = time.time()
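For reference, the calibration scheme that quantize_and_calibrate wires up above reduces to recording per-layer output ranges over a few batches and deriving an int8 scale from each range. A minimal self-contained sketch of that idea in plain NumPy; the class and function names here are illustrative, not the script's or MXNet's API:

    import numpy as np

    class MinMaxCalibrator:
        """Toy min/max calibration collector with optional range clipping."""
        def __init__(self, clip_min=None, clip_max=None):
            self.min_max = {}  # layer name -> (min, max) observed so far
            self.clip_min = clip_min
            self.clip_max = clip_max

        def collect(self, name, arr):
            lo, hi = float(np.min(arr)), float(np.max(arr))
            # Clipping rare outliers tightens the quantization range; this is
            # what BertLayerCollector does for selected layer outputs above.
            if self.clip_max is not None:
                hi = min(hi, self.clip_max)
            if self.clip_min is not None:
                lo = max(lo, self.clip_min)
            if name in self.min_max:
                prev_lo, prev_hi = self.min_max[name]
                lo, hi = min(lo, prev_lo), max(hi, prev_hi)
            self.min_max[name] = (lo, hi)

    def int8_scale(lo, hi):
        """Scale for symmetric int8 quantization from a calibrated range."""
        return max(abs(lo), abs(hi)) / 127.0

The clip_min=-50 / clip_max=10 thresholds passed to BertLayerCollector look like empirically chosen bounds for BERT-style activations; clipping keeps the int8 scale from being dominated by rare extreme values.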
From bc9ce0f899079a3cf1a8d80bbb1fb746dd3c69b5 Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Wed, 6 Oct 2021 13:52:33 +0200
Subject: [PATCH 02/10] fix

---
 scripts/question_answering/run_squad.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index 4e18dd1dfe..bc8a1bf99e 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -849,13 +849,13 @@ def collect(self, name, op_name, arr):
             """Callback function for collecting min and max values from an NDArray."""
             if name not in self.include_layers:
                 return
-            print(name)
             arr = arr.copyto(mx.cpu()).asnumpy()
             min_range = np.min(arr)
             max_range = np.max(arr)
 
             if (op_name.find("npi_copy") != -1 or op_name.find("LayerNorm") != -1) and max_range > self.clip_max:
                 max_range = self.clip_max
+
             if op_name.find('Dropout') != -1 and min_range < self.clip_min:
                 print(name, op_name)
                 min_range = self.clip_min
@@ -893,9 +893,10 @@ def evaluate(args, last=True):
     logging.info(
         'Srarting inference without horovod on the first node on device {}'.format(
             str(ctx_l)))
+    network_dtype = args.dtype if args.dtype != 'int8' else 'float32'
 
     cfg, tokenizer, qa_net, use_segmentation = get_network(
-        args.model_name, ctx_l, args.classifier_dropout, dtype=args.dtype)
+        args.model_name, ctx_l, args.classifier_dropout, dtype=network_dtype)
     if args.dtype == 'float16':
         qa_net.cast('float16')
     qa_net.hybridize()
@@ -1067,6 +1068,11 @@ def eval_validation(ckpt_name, best_eval):
 if __name__ == '__main__':
     os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
     args = parse_args()
+    if args.dtype == 'int8':
+        ctx_l = parse_ctx(args.gpus)
+        if ctx_l[0] != mx.cpu() or len(ctx_l) != 1:
+            raise ValueError("Evaluation on int8 data type is supported only for CPU for now")
+
     if args.do_train:
         if args.dtype == 'float16':
             # Initialize amp if it's fp16 training
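A note on the QuantizationDataLoader defined in the first patch: quantize_net feeds calibration batches to the backbone positionally, while the script's dataloader yields ChunkFeature namedtuples, so the wrapper reorders fields to match the backbone signature (it subclasses mx.gluon.data.DataLoader, presumably so quantize_net accepts it as calibration data). The adapter pattern in isolation, as a sketch with assumed field names:

    class PositionalBatchAdapter:
        """Yields positional argument lists for a model from namedtuple batches."""
        def __init__(self, loader, fields):
            self._loader = loader
            self._fields = fields

        def __iter__(self):
            for batch in self._loader:
                # Reorder named fields into the order the model's forward expects.
                yield [getattr(batch, field) for field in self._fields]

    # e.g. fields = ('data', 'segment_ids', 'valid_length') when segmentation is
    # used, and ('data', 'valid_length') otherwise, mirroring use_segmentation.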
From bba1525cc4ac848aca3ed452dbc15f43d8e53afb Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Fri, 8 Oct 2021 09:36:20 +0200
Subject: [PATCH 03/10] Remove quantize bool field

---
 scripts/question_answering/models.py    |  5 ++---
 scripts/question_answering/run_squad.py | 19 +++++++++----------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/scripts/question_answering/models.py b/scripts/question_answering/models.py
index 742f6eb429..db1cce2e00 100644
--- a/scripts/question_answering/models.py
+++ b/scripts/question_answering/models.py
@@ -181,7 +181,6 @@ def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
                                             weight_initializer=weight_initializer,
                                             bias_initializer=bias_initializer))
         self.quantized_backbone = None
-        self.quantized = False
 
     def get_start_logits(self, contextual_embedding, p_mask):
         """
@@ -290,7 +289,7 @@ def forward(self, tokens, token_types, valid_length, p_mask, start_position):
         answerable_logits
         """
         backbone_net = self.backbone
-        if self.quantized:
+        if self.quantized_backbone != None:
             backbone_net = self.quantized_backbone
 
         if self.use_segmentation:
@@ -344,7 +343,7 @@ def inference(self, tokens, token_types, valid_length, p_mask,
             Shape (batch_size, sequence_length, 2)
         """
         backbone_net = self.backbone
-        if self.quantized:
+        if self.quantized_backbone != None:
             backbone_net = self.quantized_backbone
 
         # Shape (batch_size, sequence_length, C)
diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index bc8a1bf99e..c7c7e866a5 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -869,15 +869,14 @@ def collect(self, name, op_name, arr):
     calib_data = QuantizationDataLoader(dataloader, net.use_segmentation)
     net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
-                                exclude_layers=None,
-                                exclude_layers_match=None,
-                                calib_data=calib_data,
-                                calib_mode='custom',
-                                LayerOutputCollector=BertLayerCollector(clip_min=-50, clip_max=10),
-                                num_calib_batches=10,
-                                ctx=mx.current_context(),
-                                logger=logging.getLogger())
-    net.quantized = True
+                                                           exclude_layers=None,
+                                                           exclude_layers_match=None,
+                                                           calib_data=calib_data,
+                                                           calib_mode='custom',
+                                                           LayerOutputCollector=BertLayerCollector(clip_min=-50, clip_max=10),
+                                                           num_calib_batches=10,
+                                                           ctx=mx.current_context(),
+                                                           logger=logging.getLogger())
     return net
 
 
@@ -891,7 +890,7 @@ def evaluate(args, last=True):
         return
     ctx_l = parse_ctx(args.gpus)
     logging.info(
-        'Srarting inference without horovod on the first node on device {}'.format(
+        'Starting inference without horovod on the first node on device {}'.format(
             str(ctx_l)))
     network_dtype = args.dtype if args.dtype != 'int8' else 'float32'
 
     cfg, tokenizer, qa_net, use_segmentation = get_network(
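A style aside on the check this patch settles on: `self.quantized_backbone != None` works, but identity comparison is the idiomatic form for None tests and sidesteps any custom __eq__ a class might define. An equivalent sketch (a hypothetical helper, not code from the patch):

    def pick_backbone(model):
        # Identity test reads better and avoids custom __eq__ pitfalls.
        if model.quantized_backbone is not None:
            return model.quantized_backbone
        return model.backbone

The behavior is otherwise identical to the committed version.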
From f2b5043608cbc68c0b67fbdcf2a3a3ef85363921 Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Tue, 12 Oct 2021 15:23:57 +0800
Subject: [PATCH 04/10] Fix electra large accuracy

---
 scripts/question_answering/run_squad.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index c7c7e866a5..da848abaae 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -868,15 +868,29 @@ def collect(self, name, op_name, arr):
             self.min_max_dict[name] = (min_range, max_range)
 
     calib_data = QuantizationDataLoader(dataloader, net.use_segmentation)
+    model_name = args.model_name
+    # disable specific layers in some models for the sake of accuracy
+
+    if model_name == 'google_albert_base_v2':
+        logging.warn(f"Currently quantized {model_name} shows significant accuracy drop which is not fixed yet")
+
+    exclude_layers_map = {"google_electra_large":
+                          ["sg_mkldnn_fully_connected_eltwise_2", "sg_mkldnn_fully_connected_eltwise_14",
+                           "sg_mkldnn_fully_connected_eltwise_18", "sg_mkldnn_fully_connected_eltwise_22",
+                           "sg_mkldnn_fully_connected_eltwise_26"
+                          ]}
+    exclude_layers = None
+    if model_name in exclude_layers_map.keys():
+        exclude_layers = exclude_layers_map[model_name]
+
     net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
-                                                           exclude_layers=None,
+                                                           exclude_layers=exclude_layers,
                                                            exclude_layers_match=None,
                                                            calib_data=calib_data,
                                                            calib_mode='custom',
                                                            LayerOutputCollector=BertLayerCollector(clip_min=-50, clip_max=10),
                                                            num_calib_batches=10,
-                                                           ctx=mx.current_context(),
-                                                           logger=logging.getLogger())
+                                                           ctx=mx.cpu())
     return net
 
 
From 94bf297517e132b08ed794df9c89d8a23579be49 Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Thu, 2 Dec 2021 16:55:55 +0100
Subject: [PATCH 05/10] Update mkldnn to onednn

---
 scripts/question_answering/run_squad.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index da848abaae..6c651daa1e 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -357,7 +357,7 @@ def get_squad_features(args, tokenizer, segment):
                                                    tokenizer=tokenizer,
                                                    is_training=is_training), data_examples)
         logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start))
-        with open(data_cache_path, 'w') as f:
+        with open(data_cache_path, 'w', encoding='utf-8') as f:
             for feature in data_features:
                 f.write(feature.to_json() + '\n')
 
@@ -875,14 +875,13 @@ def collect(self, name, op_name, arr):
     exclude_layers_map = {"google_electra_large":
-                          ["sg_mkldnn_fully_connected_eltwise_2", "sg_mkldnn_fully_connected_eltwise_14",
-                           "sg_mkldnn_fully_connected_eltwise_18", "sg_mkldnn_fully_connected_eltwise_22",
-                           "sg_mkldnn_fully_connected_eltwise_26"
+                          ["sg_onednn_fully_connected_eltwise_2", "sg_onednn_fully_connected_eltwise_14",
+                           "sg_onednn_fully_connected_eltwise_18", "sg_onednn_fully_connected_eltwise_22",
+                           "sg_onednn_fully_connected_eltwise_26"
                           ]}
     exclude_layers = None
     if model_name in exclude_layers_map.keys():
         exclude_layers = exclude_layers_map[model_name]
-
     net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
                                                            exclude_layers=exclude_layers,
                                                            exclude_layers_match=None,
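The excluded names are produced by MXNet's oneDNN subgraph fusion pass, which is why the rename above became necessary once MXNet rebranded MKL-DNN as oneDNN: the identifiers come from the fusion backend, not from the model definition, and can shift between MXNet versions. A slightly more defensive form of the lookup (a sketch; the layer list is shortened here):

    exclude_layers_map = {
        'google_electra_large': [
            'sg_onednn_fully_connected_eltwise_2',
            'sg_onednn_fully_connected_eltwise_14',
            # ... remaining fused FC layers as in the patch
        ],
    }
    # dict.get returns None for absent models, which matches quantize_net's
    # exclude_layers default, so the .keys() membership test can be dropped:
    exclude_layers = exclude_layers_map.get(model_name)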
From d56fe41e256c44090e028f9c88cc661e7bb63c29 Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Tue, 8 Mar 2022 16:53:10 +0100
Subject: [PATCH 06/10] Accuracy fix

---
 scripts/question_answering/run_squad.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index 6c651daa1e..0f851f0a49 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -853,10 +853,10 @@ def collect(self, name, op_name, arr):
             min_range = np.min(arr)
             max_range = np.max(arr)
 
-            if (op_name.find("npi_copy") != -1 or op_name.find("LayerNorm") != -1) and max_range > self.clip_max:
+            if (name.find("sg_onednn_fully_connected_eltwise") != -1 or op_name.find("LayerNorm") != -1) \
+                    and max_range > self.clip_max:
                 max_range = self.clip_max
-
-            if op_name.find('Dropout') != -1 and min_range < self.clip_min:
+            elif name.find('sg_onednn_fully_connected') != -1 and min_range < self.clip_min:
                 print(name, op_name)
                 min_range = self.clip_min

From 112071fb1ed46050c1031457148ef47ea229900e Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Wed, 9 Mar 2022 08:35:59 +0100
Subject: [PATCH 07/10] Add sphinx to dev requirements

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index baf44e6110..7b27865c84 100644
--- a/setup.py
+++ b/setup.py
@@ -133,6 +133,7 @@ def find_version(*file_paths):
         'pylint_quotes',
         'flake8',
         'recommonmark',
+        'sphinx>=1.5.5',
         'sphinx-gallery',
         'sphinx_rtd_theme',
         'mxtheme',

From 2b6997fc4fa1ea3ccd9c26933b6f5c020ae35848 Mon Sep 17 00:00:00 2001
From: bgawrych
Date: Thu, 21 Apr 2022 11:46:10 +0200
Subject: [PATCH 08/10] remove print

---
 scripts/question_answering/run_squad.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index 0f851f0a49..eadb08c2e6 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -857,7 +857,6 @@ def collect(self, name, op_name, arr):
                     and max_range > self.clip_max:
                 max_range = self.clip_max
             elif name.find('sg_onednn_fully_connected') != -1 and min_range < self.clip_min:
-                print(name, op_name)
                 min_range = self.clip_min
 
             if name in self.min_max_dict:
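After the accuracy fix and the print removal, the clipping rule in collect() reduces to: cap the maximum of fused fully-connected+eltwise and LayerNorm outputs, otherwise raise the minimum of plain fused fully-connected outputs. A toy restatement that can be sanity-checked in isolation (the function and the values are illustrative, not the script's code):

    def clip_range(name, op_name, lo, hi, clip_min=-50, clip_max=10):
        """Mirror of the final collect() clipping rule, for a single tensor."""
        if (('sg_onednn_fully_connected_eltwise' in name or 'LayerNorm' in op_name)
                and hi > clip_max):
            hi = clip_max
        elif 'sg_onednn_fully_connected' in name and lo < clip_min:
            lo = clip_min
        return lo, hi

    # FC+eltwise output: only the max is clipped, even if the min is extreme.
    assert clip_range('sg_onednn_fully_connected_eltwise_2', '', -80.0, 42.0) == (-80.0, 10)
    # Plain fused FC output: the min is clipped.
    assert clip_range('sg_onednn_fully_connected_7', '', -80.0, 5.0) == (-50, 5.0)

Note the elif: a layer matching the first branch never has its minimum clipped, which is the behavioral change this pair of patches introduced relative to the original Dropout-based rule.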
From cc1a87d215cfce8fd9bc09c80671a4ca9a388f31 Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Thu, 11 Aug 2022 09:17:19 +0200
Subject: [PATCH 09/10] change quantize_mode to proper one

---
 scripts/question_answering/run_squad.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index eadb08c2e6..c0f660a619 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -882,6 +882,7 @@ def collect(self, name, op_name, arr):
     if model_name in exclude_layers_map.keys():
         exclude_layers = exclude_layers_map[model_name]
     net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
+                                                           quantize_mode='smart',
                                                            exclude_layers=exclude_layers,
                                                            exclude_layers_match=None,
                                                            calib_data=calib_data,

From 8dc7b0523498c2b9622a8bfbd8f169c04bd272ba Mon Sep 17 00:00:00 2001
From: Bartlomiej Gawrych
Date: Thu, 11 Aug 2022 09:28:27 +0200
Subject: [PATCH 10/10] fix round_to argument

---
 scripts/question_answering/run_squad.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index c0f660a619..a3400573d4 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -120,7 +120,7 @@ def parse_args():
                              'this will be truncated to this length. default is 64')
     parser.add_argument('--pre_shuffle_seed', type=int, default=100,
                         help='Random seed for pre split shuffle')
-    parser.add_argument('--round_to', type=int, default=None,
+    parser.add_argument('--round_to', type=int, default=8,
                         help='The length of padded sequences will be rounded up to be multiple'
                              ' of this argument. When round to is set to 8, training throughput '
                              'may increase for mixed precision training on GPUs with TensorCores.')
@@ -195,13 +195,12 @@ def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
         self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id
 
         # TODO(sxjscience) Consider to combine the NamedTuple and batchify functionality.
-        # Here, we use round_to=8 to improve the throughput.
         self.BatchifyFunction = bf.NamedTuple(ChunkFeature,
                                               {'qas_id': bf.List(),
-                                               'data': bf.Pad(val=self.pad_id, round_to=8),
+                                               'data': bf.Pad(val=self.pad_id, round_to=args.round_to),
                                                'valid_length': bf.Stack(),
-                                               'segment_ids': bf.Pad(round_to=8),
-                                               'masks': bf.Pad(val=1, round_to=8),
+                                               'segment_ids': bf.Pad(round_to=args.round_to),
+                                               'masks': bf.Pad(val=1, round_to=args.round_to),
                                                'is_impossible': bf.Stack(),
                                                'gt_start': bf.Stack(),
                                                'gt_end': bf.Stack(),
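The last patch routes --round_to into the batchify Pad calls and flips its default from None to 8, so padded sequence lengths stay multiples of 8 unless a user overrides it: a shape that suits the TensorCore case mentioned in the help text and, presumably, the oneDNN int8 kernels this series targets. The rounding itself is plain ceiling arithmetic; a small sketch:

    def rounded_length(valid_length, round_to):
        """Padded length after rounding up to the nearest multiple of round_to."""
        return ((valid_length + round_to - 1) // round_to) * round_to

    assert rounded_length(61, 8) == 64
    assert rounded_length(64, 8) == 64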