diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py index 521ee15a47..01e1ae1f65 100644 --- a/scripts/question_answering/run_squad.py +++ b/scripts/question_answering/run_squad.py @@ -126,6 +126,8 @@ def parse_args(): 'may increase for mixed precision training on GPUs with TensorCores.') parser.add_argument('--overwrite_cache', action='store_true', help='Whether to overwrite the feature cache.') + parser.add_argument('--sort_input_data', action='store_true', + help='Whether to sort input data before evaluation.') # Evaluation hyperparameters parser.add_argument('--start_top_n', type=int, default=5, help='Number of start-position candidates') @@ -837,6 +839,8 @@ def evaluate(args, last=True): logging.info('Prepare dev data') dev_features = get_squad_features(args, tokenizer, segment='dev') + if args.sort_input_data: + dev_features.sort(key=lambda x: x.qas_id) dev_data_path = os.path.join(args.data_dir, 'dev-v{}.json'.format(args.version)) dataset_processor = SquadDatasetProcessor(tokenizer=tokenizer, doc_stride=args.doc_stride, @@ -848,6 +852,8 @@ def evaluate(args, last=True): chunk_features = dataset_processor.process_sample(feature) dev_all_chunk_features.extend(chunk_features) dev_chunk_feature_ptr.append(dev_chunk_feature_ptr[-1] + len(chunk_features)) + if args.sort_input_data: + dev_all_chunk_features.sort(key=lambda x: x.valid_length, reverse=True) def eval_validation(ckpt_name, best_eval): """ @@ -912,6 +918,9 @@ def eval_validation(ckpt_name, best_eval): all_predictions = collections.OrderedDict() all_nbest_json = collections.OrderedDict() no_answer_score_json = collections.OrderedDict() + if args.sort_input_data: + all_results.sort(key=lambda x: x.qas_id) + dev_all_chunk_features.sort(key=lambda x: x.qas_id) for index, (left_index, right_index) in enumerate(zip(dev_chunk_feature_ptr[:-1], dev_chunk_feature_ptr[1:])): chunked_features = dev_all_chunk_features[left_index:right_index]