-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
102 lines (88 loc) · 3.94 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Tuple, Union
import numpy as np
# from rouge_score import rouge_scorer, scoring
import nltk
# Default ROUGE variants; used as the `rouge_keys` default of the commented-out calculate_rouge helper below.
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
def line_normalize(line: str):
    """Trim *line* and collapse every internal run of whitespace into one space.

    Example: ``"  a \\t b\\n"`` becomes ``"a b"``.
    """
    # str.split() with no separator already discards leading/trailing whitespace.
    words = line.split()
    return " ".join(words)
def calculate_bleu(ref_lines, gen_lines, metrics: dict = None):
    """Compute corpus-level BLEU-1 .. BLEU-4 with NLTK and store them in ``metrics``.

    BLEU-n uses uniform weights (1/n repeated n times). Each element of
    ``ref_lines`` is wrapped as the single reference for the hypothesis at the
    same position in ``gen_lines``.

    Args:
        ref_lines: reference lines, one per hypothesis. NOTE(review): presumably
            each is a token list; a plain string would make NLTK treat its
            characters as tokens — confirm with callers.
        gen_lines: generated (hypothesis) lines, parallel to ``ref_lines``.
        metrics: dict updated in place; a fresh dict is used when None.

    Returns:
        ``metrics`` with keys ``bleu-1`` .. ``bleu-4``, each a float rounded to
        four decimal places.
    """
    metrics = {} if metrics is None else metrics
    for order in range(1, 5):
        uniform_weights = (1. / order,) * order
        score = nltk.translate.bleu_score.corpus_bleu(
            list_of_references=[[reference] for reference in ref_lines],
            hypotheses=gen_lines,
            weights=uniform_weights,
        )
        metrics[f"bleu-{order}"] = round(score, 4)
    return metrics
def extract_rouge_mid_statistics(dct):
    """Flatten the ``mid`` statistics of aggregated scores into plain nested dicts.

    Args:
        dct: mapping from metric name to an aggregate object exposing a ``mid``
            attribute, which in turn exposes ``precision``, ``recall`` and
            ``fmeasure``.

    Returns:
        ``{metric_name: {"precision": p, "recall": r, "fmeasure": f}}``, with
        each value rounded to four decimal places.
    """
    stat_names = ("precision", "recall", "fmeasure")
    flattened = {}
    for metric_name, aggregate in dct.items():
        mid_score = aggregate.mid
        rounded = {}
        for name in stat_names:
            rounded[name] = round(getattr(mid_score, name), 4)
        flattened[metric_name] = rounded
    return flattened
# def calculate_rouge(
# pred_lines: List[str],
# tgt_lines: List[str],
# use_stemmer=True,
# rouge_keys=ROUGE_KEYS,
# return_precision_and_recall=False,
# bootstrap_aggregation=True,
# newline_sep=True,
# ) -> Dict:
# """Calculate rouge using rouge_scorer package.
# Args:
# pred_lines: list of summaries generated by model
# tgt_lines: list of groundtruth summaries (e.g. contents of val.target)
# use_stemmer: Bool indicating whether Porter stemmer should be used to
# strip word suffixes to improve matching.
# rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
# return_precision_and_recall: (False) whether to also return precision and recall.
# bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True; if False,
# this function returns a collections.defaultdict[metric: list of values for each observation for each subscore].
# newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating rougeLsum
# on multi-sentence summaries (CNN/DM dataset).
# Returns:
# Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
# """
# scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
# aggregator = scoring.BootstrapAggregator()
# for pred, tgt in zip(tgt_lines, pred_lines):
# # rougeLsum expects "\n" separated sentences within a summary
# if newline_sep:
# pred = pred + "\n"
# tgt = tgt + "\n"
# scores = scorer.score(pred, tgt)
# aggregator.add_scores(scores)
# if bootstrap_aggregation:
# result = aggregator.aggregate()
# if return_precision_and_recall:
# return extract_rouge_mid_statistics(result) # here we return dict
# else:
# return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
# else:
# return aggregator._scores # here we return defaultdict(list)
def repetition_distinct_metric(gen_lines, metrics: dict = None, repetition_times=2):
    """Compute repetition-n and distinct-n scores (n = 1..4) over generated lines.

    For each n in 1..4:
      * ``repetition-n``: fraction of lines containing at least one n-gram that
        occurs ``repetition_times`` or more times within that same line.
      * ``distinct-n``: number of unique n-grams across all lines divided by the
        total number of n-grams across all lines.

    Args:
        gen_lines: collection of generated lines. Each line is treated as a
            sequence of tokens; presumably a list of word tokens (TODO confirm
            with callers) — if a plain string is passed, its characters are the
            tokens.
        metrics: dict updated in place; a fresh dict is used when None.
        repetition_times: minimum in-line occurrence count for an n-gram to
            count as a repetition.

    Returns:
        ``metrics`` with keys ``repetition-1`` .. ``repetition-4`` and
        ``distinct-1`` .. ``distinct-4``. Each value is a string formatted with
        four decimals (e.g. ``"0.1234"``). A ratio whose denominator is zero
        (no lines, or no n-grams of that order) is reported as ``"0.0000"``
        instead of raising ZeroDivisionError.
    """
    from collections import Counter

    if metrics is None:
        metrics = {}
    # Materialize every line once so each n-gram order sees the same tokens.
    token_lines = [tuple(line) for line in gen_lines]
    num_lines = len(token_lines)
    for gram_n in range(1, 5):
        repetition_count = 0
        unique_ngrams = set()
        total_ngrams = 0
        for tokens in token_lines:
            # Contiguous n-grams as tuples (same as nltk.ngrams without padding).
            # Tuples instead of "_"-joined strings so tokens containing "_"
            # cannot collide into the same key.
            n_grams = list(zip(*(tokens[i:] for i in range(gram_n))))
            total_ngrams += len(n_grams)
            unique_ngrams.update(n_grams)
            # One Counter pass instead of list.count() per distinct n-gram.
            if any(count >= repetition_times for count in Counter(n_grams).values()):
                repetition_count += 1
        repetition_ratio = repetition_count / num_lines if num_lines else 0.0
        distinct_ratio = len(unique_ngrams) / total_ngrams if total_ngrams else 0.0
        metrics[f"repetition-{gram_n}"] = "%.4f" % repetition_ratio
        metrics[f"distinct-{gram_n}"] = "%.4f" % distinct_ratio
    return metrics