From 70d93e41319d2e13f74795f6f417ef73b9471456 Mon Sep 17 00:00:00 2001 From: Tomohide Shibata Date: Wed, 1 Jun 2022 17:27:19 +0900 Subject: [PATCH] Add fine-tuning --- README.md | 11 +- fine-tuning/README.md | 145 ++++++++ .../transformers-4.9.2_jglue-1.0.0.patch | 320 ++++++++++++++++++ fine-tuning/scripts/generate_results.py | 100 ++++++ 4 files changed, 571 insertions(+), 5 deletions(-) create mode 100644 fine-tuning/README.md create mode 100644 fine-tuning/patch/transformers-4.9.2_jglue-1.0.0.patch create mode 100644 fine-tuning/scripts/generate_results.py diff --git a/README.md b/README.md index 39075ec..12fcf17 100644 --- a/README.md +++ b/README.md @@ -195,8 +195,9 @@ When you use NICT BERT base or Waseda RoBERTa base models, the dataset text shou Please refer to [preprocess/morphological-analysis/README.md](/preprocess/morphological-analysis/README.md). The fine-tuning was performed using [the transformers -library](https://github.com/huggingface/transformers) provided by Hugging Face. The performance along -with human scores on the JGLUE dev set is shown below. +library](https://github.com/huggingface/transformers) provided by Hugging Face. See [fine-tuning/README.md](/fine-tuning/README.md) for details. + +The performance along with human scores on the JGLUE dev set is shown below. |Model|MARC-ja|JSTS|JNLI|JSQuAD|JCommonsenseQA| |-----|-------|-------|-------|-------|-------| @@ -209,10 +210,10 @@ with human scores on the JGLUE dev set is shown below. |NICT BERT base|0.958|0.903/0.867|0.902|**0.897**/**0.947**|0.823| |Waseda RoBERTa base|0.962|0.901/0.865|0.895|0.864/0.927|0.840| |Waseda RoBERTa large|0.954|**0.923**/**0.891**|**0.924**|0.884/0.940|**0.901**| -|XLM RoBERTa base|0.961|0.870/0.825|0.893|-/-|0.687| -|XLM RoBERTa large|**0.964**|0.915/0.882|0.919|-/-|0.840| - +|XLM RoBERTa base|0.961|0.870/0.825|0.893|-/-†|0.687| +|XLM RoBERTa large|**0.964**|0.915/0.882|0.919|-/-†|0.840| +†XLM RoBERTa base/large models use the unigram language model as a tokenizer and they are excluded from the JSQuAD evaluation because the token delimitation and the start/end of the answer span often do not match, resulting in poor performance. ## Leaderboard diff --git a/fine-tuning/README.md b/fine-tuning/README.md new file mode 100644 index 0000000..658fa77 --- /dev/null +++ b/fine-tuning/README.md @@ -0,0 +1,145 @@ +# Fine-tuning + +We used [the transformers library](https://github.com/huggingface/transformers) for our fine-tuning experiments, and modified the original codes to fit our datasets and experimental settings. We used v4.9.2, but other versions may work. + +```bash +# This directory is for installing transformers. This directory would be better outside the JGLUE repository.) +$ cd /somewhere/ +$ git clone https://github.com/huggingface/transformers.git -b v4.9.2 transformers-4.9.2 +$ cd transformers-4.9.2 +$ patch -p1 < /somewhere2/JGLUE/fine-tuning/patch/transformers-4.9.2_jglue-1.0.0.patch +$ pip install . +$ pip install -r examples/pytorch/text-classification/requirements.txt +$ pip install protobuf==3.19.1 tensorboard +# For the cl-tohoku/bert-base-japanese-v2 model +$ pip install fugashi unidic-lite +``` + +## Hyperparameters + +The following table lists hyperparameters used in our experiments. The numbers in curly brackets represent the range of possible values. The best hyperparameters were searched using the dev set. + +|Name|Value(s)| +|----|-------| +|learning rate|{5e-5, 3e-5, 2e-5}| +|epoch|{3, 4}| +|warmup ratio|0.1| +|max seq length|512 (MARC-ja), 128 (JSTS, JNLI), 384 (JSQuAD), 64 (JCommonsenseQA)| + +## Text classification and sentence pair classification tasks + +This section is for `MARC-ja`, `JSTS` and `JNLI`. When you fine-tune the `cl-tohoku/bert-base-japanese-v2` model for the `MARC-ja` dataset, run the following command: + +```bash +$ export OUTPUT_DIR=/path/to/output_marc +$ python /somewhere/transformers-4.9.2/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path cl-tohoku/bert-base-japanese-v2 \ + --metric_name sst2 \ + --do_train --do_eval --do_predict \ + --max_seq_length 512 \ + --per_device_train_batch_size 32 \ + --learning_rate 5e-05 \ + --num_train_epochs 4 \ + --output_dir $OUTPUT_DIR \ + --train_file ../datasets/marc_ja-v1.0/train-v1.0.json \ + --validation_file ../datasets/marc_ja-v1.0/valid-v1.0.json \ + --test_file ../datasets/marc_ja-v1.0/valid-v1.0.json \ + --use_fast_tokenizer False \ + --evaluation_strategy epoch \ + --save_steps 5000 \ + --warmup_ratio 0.1 +``` + +`--metric_name` option should be set according to a dataset as follows: +- MARC-ja: sst2 +- JSTS: stsb +- JNLI: wnli + +When you fine-tune the NICT BERT base or Waseda RoBERTa base/large models, please specify the word-segmented files with `--train_file`, `--validation_file` and `--test_file` options (see [preprocess/morphological-analysis/README.md](/preprocess/morphological-analysis/README.md) for how to perform word segmentation). For example, you use Waseda RoBERTa base/large models, these options are as follows: + +```bash + .. + --train_file ../datasets/marc_ja-v1.0_jumanpp/train-v1.0.json \ + --validation_file ../datasets/marc_ja-v1.0_jumanpp/valid-v1.0.json \ + -test_file ../datasets/marc_ja-v1.0_jumanpp/valid-v1.0.json \ + .. +``` + +The system prediction for the validation set is output to `$OUTPUT_DIR/predict_results_${metric_name}.txt`. +For the examination of the system prediction, the system prediction as well as the evaluation result can be output as follows: + +- MARC-ja +```bash +$ python scripts/generate_results.py \ + --system-predict-txt $OUTPUT_DIR/predict_results_sst2.txt \ + --input-file ../datasets/marc_ja-v1.0/valid-v1.0.json \ + --task-type single-sentence \ + --additional-column-name-string review_id > $OUTPUT_DIR/predict_eval_results.tsv +``` +- JSTS +```bash +$ python scripts/generate_results.py \ + --system-predict-txt $OUTPUT_DIR/predict_results_stsb.txt \ + --input-file ../datasets/jsts-v1.0/valid-v1.0.json \ + --task-type sentence-pair \ + --classification-type regression \ + --additional-column-name-string sentence_pair_id,yjcaptions_id > $OUTPUT_DIR/predict_eval_results.tsv +``` + +- JNLI +```bash +$ python scripts/generate_results.py \ + --system-predict-txt $OUTPUT_DIR/predict_results_wnli.txt \ + --input-file ../datasets/jnli-v1.0/valid-v1.0.json \ + --task-type sentence-pair \ + --additional-column-name-string sentence_pair_id,yjcaptions_id > $OUTPUT_DIR/predict_eval_results.tsv +``` +## QA: JSQuAD + +``` +$ export OUTPUT_DIR=/path/to/output_jsquad +$ python /somewhere/transformers-4.9.2/examples/legacy/question-answering/run_squad.py \ + --model_type bert \ + --model_name_or_path cl-tohoku/bert-base-japanese-v2 \ + --do_train --do_eval \ + --max_seq_length 384 \ + --learning_rate 5e-05 \ + --num_train_epochs 3 \ + --per_gpu_train_batch_size 32 \ + --per_gpu_eval_batch_size 32 \ + --output_dir $OUTPUT_DIR \ + --train_file ../datasets/jsquad-v1.0/train-v1.0.json \ + --predict_file ../datasets/jsquad-v1.0/valid-v1.0.json \ + --save_steps 5000 \ + --warmup_ratio 0.1 \ + --evaluate_prefix eval +``` + +## QA: JCommonsenseQA + +```bash +$ export OUTPUT_DIR=/path/to/output_jcommonsenseqa +$ python /somewhere/transformers-4.9.2/examples/pytorch/multiple-choice/run_swag.py \ + --model_name_or_path cl-tohoku/bert-base-japanese-v2 \ + --do_train --do_eval --do_predict \ + --max_seq_length 64 \ + --per_device_train_batch_size 64 \ + --learning_rate 5e-05 \ + --num_train_epochs 4 \ + --output_dir $OUTPUT_DIR \ + --train_file ../datasets/jcommonsenseqa-v1.0/train-v1.0.json \ + --validation_file ../datasets/jcommonsenseqa-v1.0/valid-v1.0.json \ + --test_file ../datasets/jcommonsenseqa-v1.0/valid-v1.0.json \ + --use_fast_tokenizer False \ + --evaluation_strategy epoch \ + --warmup_ratio 0.1 +``` + +For the examination of the system prediction, the system prediction as well as the evaluation result can be output as follows: +```bash +$ python scripts/generate_results.py \ + --system-predict-txt $OUTPUT_DIR/predict_results_valid.txt \ + --input-file ../datasets/jcommonsenseqa-v1.0/valid-v1.0.json \ + --task-type swag \ + --additional-column-name-string q_id > $OUTPUT_DIR/predict_eval_results.tsv +``` diff --git a/fine-tuning/patch/transformers-4.9.2_jglue-1.0.0.patch b/fine-tuning/patch/transformers-4.9.2_jglue-1.0.0.patch new file mode 100644 index 0000000..eb570d3 --- /dev/null +++ b/fine-tuning/patch/transformers-4.9.2_jglue-1.0.0.patch @@ -0,0 +1,320 @@ +diff --git a/examples/legacy/question-answering/run_squad.py b/examples/legacy/question-answering/run_squad.py +index fbf2ebd..041010e 100644 +--- a/examples/legacy/question-answering/run_squad.py ++++ b/examples/legacy/question-answering/run_squad.py +@@ -22,6 +22,9 @@ import logging + import os + import random + import timeit ++import json ++import math ++ + + import numpy as np + import torch +@@ -98,6 +101,10 @@ def train(args, train_dataset, model, tokenizer): + {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) ++ if args.warmup_ratio > 0: ++ args.warmup_steps = math.ceil(t_total * args.warmup_ratio) ++ logger.info("Warmup steps = %d", args.warmup_steps) ++ + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total + ) +@@ -600,6 +607,7 @@ def main(): + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") ++ parser.add_argument("--warmup_ratio", default=0.0, type=float, help="Linear warmup ratio.") + parser.add_argument( + "--n_best_size", + default=20, +@@ -626,6 +634,13 @@ def main(): + help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)", + ) + ++ parser.add_argument( ++ "--evaluate_prefix", ++ default=None, ++ type=str, ++ help="evaluate prefix", ++ ) ++ + parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") + parser.add_argument( +@@ -816,12 +831,15 @@ def main(): + model.to(args.device) + + # Evaluate +- result = evaluate(args, model, tokenizer, prefix=global_step) ++ result = evaluate(args, model, tokenizer, prefix=args.evaluate_prefix if args.evaluate_prefix is not None else global_step) + + result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items()) + results.update(result) + + logger.info("Results: {}".format(results)) ++ eval_json_file = os.path.join(args.output_dir, "{}_results.json".format(args.evaluate_prefix)) ++ with open(eval_json_file, "w") as writer: ++ writer.write(json.dumps(results, indent=4) + "\n") + + return results + +diff --git a/examples/legacy/question-answering/run_squad_trainer.py b/examples/legacy/question-answering/run_squad_trainer.py +index 7089326..e6a7032 100644 +--- a/examples/legacy/question-answering/run_squad_trainer.py ++++ b/examples/legacy/question-answering/run_squad_trainer.py +@@ -172,7 +172,7 @@ def main(): + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) +- if trainer.is_world_master(): ++ if trainer.is_world_process_zero(): + tokenizer.save_pretrained(training_args.output_dir) + + +diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py +index 9446ce3..b013b02 100755 +--- a/examples/pytorch/multiple-choice/run_swag.py ++++ b/examples/pytorch/multiple-choice/run_swag.py +@@ -99,6 +99,10 @@ class DataTrainingArguments: + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) ++ test_file: Optional[str] = field( ++ default=None, ++ metadata={"help": "An optional input prediction data file."}, ++ ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) +@@ -268,6 +272,8 @@ def main(): + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file ++ if data_args.test_file is not None: ++ data_files["test"] = data_args.test_file + extension = data_args.train_file.split(".")[-1] + raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + else: +@@ -304,9 +310,8 @@ def main(): + ) + + # When using your own dataset or a different dataset from swag, you will probably need to change this. +- ending_names = [f"ending{i}" for i in range(4)] +- context_name = "sent1" +- question_header_name = "sent2" ++ ending_names = [f"choice{i}" for i in range(5)] ++ context_name = "question" + + if data_args.max_seq_length is None: + max_seq_length = tokenizer.model_max_length +@@ -326,10 +331,9 @@ def main(): + + # Preprocessing the datasets. + def preprocess_function(examples): +- first_sentences = [[context] * 4 for context in examples[context_name]] +- question_headers = examples[question_header_name] ++ first_sentences = [[context] * 5 for context in examples[context_name]] + second_sentences = [ +- [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers) ++ [f"{examples[end][i]}" for end in ending_names] for i in range(len(examples[context_name])) + ] + + # Flatten out +@@ -345,7 +349,7 @@ def main(): + padding="max_length" if data_args.pad_to_max_length else False, + ) + # Un-flatten +- return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()} ++ return {k: [v[i : i + 5] for i in range(0, len(v), 5)] for k, v in tokenized_examples.items()} + + if training_args.do_train: + if "train" not in raw_datasets: +@@ -375,6 +379,18 @@ def main(): + load_from_cache_file=not data_args.overwrite_cache, + ) + ++ if training_args.do_predict: ++ if "test" not in raw_datasets: ++ raise ValueError("--do_predict requires a test dataset") ++ predict_dataset = raw_datasets["test"] ++ with training_args.main_process_first(desc="test dataset map pre-processing"): ++ predict_dataset = predict_dataset.map( ++ preprocess_function, ++ batched=True, ++ num_proc=data_args.preprocessing_num_workers, ++ load_from_cache_file=not data_args.overwrite_cache, ++ ) ++ + # Data collator + data_collator = ( + default_data_collator +@@ -430,6 +446,21 @@ def main(): + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + ++ if training_args.do_predict: ++ logger.info("*** Predict ***") ++ ++ predictions = trainer.predict(predict_dataset) ++ prediction_list = np.argmax(predictions.predictions, axis=1) ++ output_predict_file = os.path.join(training_args.output_dir, f"predict_results_valid.txt") ++ ++ if trainer.is_world_process_zero(): ++ with open(output_predict_file, "w") as writer: ++ logger.info(f"***** Predict results *****") ++ writer.write("index\tprediction\n") ++ ++ for i, prediction in enumerate(prediction_list): ++ writer.write(f"{i}\t{prediction}\n") ++ + if training_args.push_to_hub: + trainer.push_to_hub( + finetuned_from=model_args.model_name_or_path, +diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py +index 257a593..de95ece 100755 +--- a/examples/pytorch/text-classification/run_glue.py ++++ b/examples/pytorch/text-classification/run_glue.py +@@ -80,6 +80,10 @@ class DataTrainingArguments: + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) ++ metric_name: Optional[str] = field( ++ default=None, ++ metadata={"help": "The name of the metric for evaluation as a GLUE task name: " + ", ".join(task_to_keys.keys())}, ++ ) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) +@@ -339,6 +343,8 @@ def main(): + # Preprocessing the raw_datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] ++ elif data_args.metric_name is not None: ++ sentence1_key, sentence2_key = task_to_keys[data_args.metric_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. + non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] +@@ -436,6 +442,8 @@ def main(): + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) ++ elif data_args.metric_name is not None: ++ metric = load_metric("glue", data_args.metric_name) + else: + metric = load_metric("accuracy") + +@@ -444,7 +452,7 @@ def main(): + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) +- if data_args.task_name is not None: ++ if data_args.task_name is not None or data_args.metric_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() +@@ -519,7 +527,9 @@ def main(): + logger.info("*** Predict ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) +- tasks = [data_args.task_name] ++ # tasks = [data_args.task_name] ++ tasks = [data_args.metric_name] ++ + predict_datasets = [predict_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") +diff --git a/src/transformers/data/metrics/squad_metrics.py b/src/transformers/data/metrics/squad_metrics.py +index f55e827..ceac867 100644 +--- a/src/transformers/data/metrics/squad_metrics.py ++++ b/src/transformers/data/metrics/squad_metrics.py +@@ -38,15 +38,20 @@ def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): +- regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) +- return re.sub(regex, " ", text) ++ return text.rstrip("。") ++ ++ # regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) ++ # return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): +- exclude = set(string.punctuation) +- return "".join(ch for ch in text if ch not in exclude) ++ # do nothing ++ return text ++ ++ # exclude = set(string.punctuation) ++ # return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() +@@ -65,8 +70,10 @@ def compute_exact(a_gold, a_pred): + + + def compute_f1(a_gold, a_pred): +- gold_toks = get_tokens(a_gold) +- pred_toks = get_tokens(a_pred) ++ # character-base ++ gold_toks = list(normalize_answer(a_gold)) ++ pred_toks = list(normalize_answer(a_pred)) ++ + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: +@@ -252,7 +259,7 @@ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_ + return evaluation + + +-def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ++def get_final_text(pred_text, orig_text, do_lower_case, tokenizer, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original +@@ -295,7 +302,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. +- tokenizer = BasicTokenizer(do_lower_case=do_lower_case) ++ # tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + +@@ -511,7 +518,15 @@ def compute_predictions_logits( + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + +- final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) ++ # final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) ++ # When you use BertJapaneseTokenizer, the input text is not tokenized. ++ # So, we cannot use the get_final_text function ++ if tokenizer.__class__.__name__ == "BertJapaneseTokenizer": ++ final_text = "".join(tok_text.split(" ")) ++ else: ++ final_text = get_final_text(tok_text, orig_text, do_lower_case, tokenizer, verbose_logging) ++ # final_text = " ".join(tok_text.split(" ")) ++ + if final_text in seen_predictions: + continue + +diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py +index cea84fb..ed9ec70 100644 +--- a/src/transformers/data/processors/squad.py ++++ b/src/transformers/data/processors/squad.py +@@ -113,7 +113,8 @@ def squad_convert_example_to_features( + + # If the answer cannot be found in the text, then skip this example. + actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)]) +- cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) ++ # cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) ++ cleaned_answer_text = example.answer_text + if actual_text.find(cleaned_answer_text) == -1: + logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'") + return [] diff --git a/fine-tuning/scripts/generate_results.py b/fine-tuning/scripts/generate_results.py new file mode 100644 index 0000000..655ee1a --- /dev/null +++ b/fine-tuning/scripts/generate_results.py @@ -0,0 +1,100 @@ +import sys +import io +import os +import argparse +import json +import csv + + +def get_input_string(input_data, task_type, additional_column_name_string): + input_strings = [] + if additional_column_name_string is not None: + for additional_column_name in additional_column_name_string.split(","): + input_strings.append(str(input_data[additional_column_name])) + + if task_type == "single-sentence": + input_strings.append(input_data["sentence"]) + elif task_type == "sentence-pair": + input_strings.extend([input_data["sentence1"], input_data["sentence2"]]) + elif task_type == "swag": + input_strings.extend([str(input_data["question"]), + input_data["choice0"], input_data["choice1"], input_data["choice2"], + input_data["choice3"], input_data["choice4"]]) + + return "\t".join(input_strings) + + +def get_eval_string(input_data, system_predict, + classification_type, task_type): + if classification_type == "regression": + return "{:.3f}".format(abs(float(system_predict) - float(input_data["label"]))) + elif classification_type == "classification": + if task_type == "swag": + return "CORRECT" if int(system_predict) == input_data["label"] else "WRONG" + else: + return "CORRECT" if system_predict == str(input_data["label"]) else "WRONG" + + +def print_header(args): + columns = [] + if args.additional_column_name_string is not None: + columns.extend(args.additional_column_name_string.split(",")) + if args.task_type == "single-sentence": + columns.append("sentence") + elif args.task_type == "sentence-pair": + columns.extend(["sentence1", "sentence2"]) + elif args.task_type == "swag": + columns.extend(["question", "choice0", "choice1", "choice2", "choice3", "choice4"]) + + columns.append("system") + columns.append("gold") + columns.append("eval") + + print("\t".join(columns)) + + +def main(args): + print_header(args) + + with open(args.system_predict_txt, "r", encoding="utf-8") as system_f: + next(system_f) + + with open(args.input_file, "r", encoding="utf-8") as input_f: + for row_system_predict, input in zip(csv.reader(system_f, delimiter="\t"), + input_f.readlines() if args.input_file_type == "json" + else csv.DictReader(input_f)): + input_data = None + if args.input_file_type == "json": + input_data = json.loads(input) + else: + input_data = input + + system_predict = row_system_predict[1] + + input_string = get_input_string(input_data, args.task_type, + args.additional_column_name_string) + eval_string = get_eval_string(input_data, system_predict, + args.classification_type, args.task_type) + # input system goal evaluation + print("{}\t{}\t{}\t{}".format(input_string, + system_predict, + input_data["label"], + eval_string + )) + + +if __name__ == "__main__": + sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + + parser = argparse.ArgumentParser(description="print system result") + parser.add_argument("--system-predict-txt", type=str, default=None, required=True, help="system prediction file") + parser.add_argument("--input-file", type=str, default=None, required=True, help="input file") + parser.add_argument("--input-file-type", choices=["json", "csv"], default="json") + parser.add_argument("--task-type", choices=["single-sentence", "sentence-pair", "swag"], default="single-sentence") + parser.add_argument("--classification-type", choices=["classification", "regression"], default="classification") + parser.add_argument("--additional-column-name-string", type=str, default=None, help="additional column name string (comma separated)") + args = parser.parse_args() + + main(args)