evaluator.py
"""
The script that help evaluate the BLEU score of model, provided by TA from Mila IFT6759
(https://github.com/mila-iqia/ift6759/blob/master/projects/project2/tokenizer.py)
The BLEU score here used the implementation from sacreBLEU
(https://github.com/mjpost/sacreBLEU)
"""
import argparse
import subprocess
import tempfile


def generate_predictions(input_file_path: str, pred_file_path: str):
    """Generates predictions for the machine translation task (EN->FR).

    You are allowed to modify this function as needed, but once again, you cannot
    modify any other part of this file. We will be importing only this function
    in our final evaluation script. Since you will most definitely need to import
    modules for your code, you must import these inside the function itself.

    Args:
        input_file_path: the file path that contains the input data.
        pred_file_path: the file path where to store the predictions.

    Returns: None
    """
    ##### MODIFY BELOW #####
    # Wrap test_evaluation.py as a function here.
    import pickle

    from Transformers_Google import Transformer
    from evaluation import translate_batch
    from dataloaders_processed import load_test_generator

    root_path = "/project/cq-training-1/project2/submissions/team01/Low-Resources-Machine-Translation/"

    # Rebuild the model and restore the trained weights.
    transformer = Transformer(4, 256, 8, 1024, 20000, 20000, 20000, 20000, 0.1, None, None)
    transformer.load_weights(root_path + "model_weights/transformers-weights")
    print("Weights loaded in transformer")

    # Restore the tokenizers fitted during training.
    input_tokenizer = pickle.load(open(root_path + "tokenizers/input_tokenizer.pkl", "rb"))
    target_tokenizer = pickle.load(open(root_path + "tokenizers/target_tokenizer.pkl", "rb"))
    print("Tokenizers loaded")

    batch_size = 256
    test_dataset = load_test_generator(input_file_path, input_tokenizer, batch_size)
    print("Test generator prepared")

    # Translate batch by batch, writing one prediction per line.
    with open(pred_file_path, 'w', encoding='utf-8', buffering=1) as pred_file:
        for batch, inp in enumerate(test_dataset):
            if batch % 2 == 0:
                print("Evaluating for batch", batch)
            preds = translate_batch(inp, target_tokenizer, transformer, max_length_targ=120)
            for p_fr in preds:
                pred_file.write(p_fr.strip() + '\n')

    # To silence an exception at interpreter shutdown.
    transformer = None
    del transformer
    ##### MODIFY ABOVE #####
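
# NOTE: Transformer, translate_batch, and load_test_generator above are
# project-local helpers, not public packages. generate_predictions assumes that
# load_test_generator yields batches of tokenized source sentences and that
# translate_batch returns one decoded French string per source sentence; these
# assumptions are inferred from how the helpers are called here, not documented
# elsewhere in this file.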


def compute_bleu(pred_file_path: str, target_file_path: str, print_all_scores: bool):
    """Computes the BLEU score of the predictions with sacreBLEU.

    Args:
        pred_file_path: the file path that contains the predictions.
        target_file_path: the file path that contains the targets (also called references).
        print_all_scores: if True, will print one score per example.

    Returns: None
    """
    # Run sacreBLEU on the prediction/reference files; '--sentence-level' prints
    # one BLEU score per line and '--score-only' suppresses everything but the numbers.
    out = subprocess.run(
        ["sacrebleu", "--input", pred_file_path, target_file_path, '--tokenize', 'none', '--sentence-level',
         '--score-only'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    lines = out.stdout.split('\n')
    if print_all_scores:
        print('\n'.join(lines[:-1]))
    else:
        scores = [float(x) for x in lines[:-1]]
        print('final avg bleu score: {:.2f}'.format(sum(scores) / len(scores)))
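
# Note that averaging sentence-level BLEU, as above, is not the same as
# corpus-level BLEU, which pools n-gram counts over the whole test set. If a
# corpus-level number were ever wanted, a minimal sketch using the sacreBLEU
# Python API (not used by this evaluation pipeline) would be:
#
#   import sacrebleu
#   with open(pred_file_path) as p, open(target_file_path) as t:
#       hyps = [line.strip() for line in p]
#       refs = [[line.strip() for line in t]]
#   print(sacrebleu.corpus_bleu(hyps, refs, tokenize='none').score)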


def main():
    parser = argparse.ArgumentParser(description='Script for evaluating a model.')
    parser.add_argument('--target-file-path', help='path to target (reference) file', required=True)
    parser.add_argument('--input-file-path', help='path to input file', required=True)
    parser.add_argument('--print-all-scores', help='will print one score per sentence', action='store_true')
    parser.add_argument('--do-not-run-model', help='will use --input-file-path as predictions, instead of running the '
                                                   'model on it', action='store_true')
    args = parser.parse_args()

    if args.do_not_run_model:
        # The input file already contains predictions; just score it.
        compute_bleu(args.input_file_path, args.target_file_path, args.print_all_scores)
    else:
        # Run the model first, writing predictions to a temporary file.
        _, pred_file_path = tempfile.mkstemp()
        generate_predictions(args.input_file_path, pred_file_path)
        compute_bleu(pred_file_path, args.target_file_path, args.print_all_scores)


if __name__ == '__main__':
    main()