-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompute_textgen_metrics.py
56 lines (42 loc) · 1.95 KB
/
compute_textgen_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import evaluate
from datasets import load_dataset
from comet import download_model, load_from_checkpoint
from train_t5 import read_textfile
from sacrebleu import BLEU
from t5_utils import MODELS_DIR
# Template for a task's local data directory; formatted with the task name.
DATA_PATH = 'data/{}/'

# MT-specific parameters: language codes used to build the language-pair
# directory name and the file extensions of the wmt22 test files.
SRC = 'en'
TGT = 'de'

# Per-task pair of dataset field names. Only the second element (index 1) is
# read in this file, where it selects the reference text of each datapoint.
# NOTE(review): the first element looks like the input field -- confirm.
task_features = {
    'e2e_nlg_cleaned': ('meaning_representation', 'human_reference'),
    'xsum': ('document', 'summary'),
    'wmt22': ('source', 'reference'),
}
def main(args):
    """Score a hypothesis file against the test references of ``args.task``.

    BLEU is always printed. ROUGE is additionally printed for
    ``e2e_nlg_cleaned`` and a COMET system score for ``wmt22``.

    Args:
        args: Parsed command-line namespace providing ``task`` (a key of
            ``task_features``) and ``hyp`` (path of the hypothesis file,
            loaded with ``read_textfile``).

    Raises:
        ValueError: If the number of hypotheses does not match the number
            of references, or (``wmt22`` only) the number of source
            segments does not match the number of references.
    """
    # NOTE(review): assumes read_textfile returns one string per line -- confirm.
    predictions = read_textfile(args.hyp)
    split = 'test'
    if args.task != 'wmt22':
        # Non-MT tasks: references come from the dataset loaded by task name.
        test_data = load_dataset(args.task, split=split)
        reference_field = task_features[args.task][1]
        references = [datapoint[reference_field] for datapoint in test_data]
    else:
        # wmt22: source and reference segments are read from local text files.
        data_dir = DATA_PATH.format(args.task) + '/{}-{}/'.format(SRC, TGT)
        source_data = read_textfile('{}/{}.{}'.format(data_dir, split, SRC))
        references = read_textfile('{}/{}.{}'.format(data_dir, split, TGT))
        # zip() below would silently truncate on a mismatch, scoring fewer
        # (and possibly misaligned) segments, so fail loudly instead.
        if len(source_data) != len(references):
            raise ValueError(
                'Number of source segments ({}) does not match number of references ({})'.format(
                    len(source_data), len(references)))

    # Explicit check instead of `assert`, which is removed under `python -O`.
    if len(predictions) != len(references):
        raise ValueError(
            'Number of hypotheses ({}) does not match number of references ({})'.format(
                len(predictions), len(references)))

    # A single reference stream, hence the one-element list of references.
    bleu = BLEU()
    print(bleu.corpus_score(predictions, [references]))

    # NOTE(review): xsum currently gets BLEU only -- confirm this is intended.
    if args.task == 'e2e_nlg_cleaned':
        rouge = evaluate.load("rouge")
        results = rouge.compute(predictions=predictions, references=references)
        print(results)
    elif args.task == 'wmt22':
        model_path = download_model("Unbabel/wmt22-cometkiwi-da", saving_directory=MODELS_DIR)
        model = load_from_checkpoint(model_path)
        # NOTE(review): CometKiwi checkpoints are documented as reference-free,
        # so the "ref" field is presumably unused by this model -- confirm.
        data = [{"src": x, "mt": y, "ref": z} for x, y, z in zip(source_data, predictions, references)]
        model_output = model.predict(data, batch_size=200, gpus=1)
        print("---COMET score: ", model_output.system_score)
if __name__ == "__main__":
    # Command-line entry point: parse the task name and hypothesis file, then
    # hand the parsed namespace to main().
    parser = argparse.ArgumentParser(description='Reproduce document-level metric scores from the paper.')
    # Only the tasks listed in task_features can be scored; reject anything
    # else at parse time instead of failing later on the dictionary lookup.
    parser.add_argument('--task', default="wmt22", type=str, choices=sorted(task_features),
                        help='the task name')
    # There is no usable default hypothesis file, so require the option rather
    # than passing None on to read_textfile().
    parser.add_argument('--hyp', type=str, required=True, help='the hypothesis file')
    args = parser.parse_args()
    main(args)