-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
60 lines (38 loc) · 1.97 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import csv
from pathlib import Path

import numpy as np

from utils.generation import iterate_geneator, score_normalized, score_ratio
from utils.data import syn_data
from utils.get_prompt import syn_defs
# Score synonym sets with two language models (gpt2 and microsoft/biogpt) and
# write per-set scores plus their ratios to a TSV file.
#
# Estimated number of model generations per run:
#   2 * N(prompt) * N(synonyms) * LEN(tokenized(synonyms))
# (presumably the factor 2 is the two models queried in score_row).

output_file = "./output/p4.tsv"

# False: build one "What is: <definition>?" prompt per synonym set from syn_defs.
# True:  score every synonym set under the single FIXED_PROMPT below.
fixed_prompt = False

# Fixed prompt used when fixed_prompt is True. Earlier variants were
# "I have " (p1) and "The easiest term is " (p2).
FIXED_PROMPT = "The easiest term is "

HEADER = ["synonyms", "scores_gpt", "scores_biogpt", "un_normalized_score", "normalized_score"]


def build_prompt(definition):
    """Return the per-definition question prompt, e.g. ``"What is: fever?"``."""
    return "What is: " + definition + "?"


def iter_prompts(definitions, synonym_sets, fixed=None):
    """Yield ``(prompt, synonyms)`` pairs to be scored.

    Args:
        definitions: One definition string per synonym set. Ignored when
            *fixed* is given.
        synonym_sets: The synonym sets to score.
        fixed: If not None, use this prompt for every synonym set.

    Yields:
        ``(prompt, synonyms)`` tuples in input order.
    """
    if fixed is not None:
        for synonyms in synonym_sets:
            yield fixed, synonyms
        return
    # zip() stops at the shorter input, as the original loop did, so extra
    # definitions or synonym sets are silently dropped.
    for definition, synonyms in zip(definitions, synonym_sets):
        yield build_prompt(definition), synonyms


def score_row(prompt, synonyms):
    """Score *synonyms* under *prompt* with both models and return one TSV row.

    The row has the same column order as HEADER.
    """
    scores_gpt = iterate_geneator("gpt2", prompt=prompt, synonyms=synonyms)
    scores_biogpt = iterate_geneator("microsoft/biogpt", prompt=prompt, synonyms=synonyms)
    un_normalized_score = score_ratio(scores_gpt, scores_biogpt)
    normalized_score = score_ratio(score_normalized(scores_gpt), score_normalized(scores_biogpt))
    return [synonyms, scores_gpt, scores_biogpt, un_normalized_score, normalized_score]


def main():
    """Score every synonym set and write the results to ``output_file``."""
    # csv.writer converts non-string fields with str(), so this global setting
    # decides how (presumably numpy) score arrays are rendered in the TSV.
    # It is set once, before any row is written.
    np.set_printoptions(precision=3)

    # open() fails if ./output/ does not exist yet.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    prompt = FIXED_PROMPT if fixed_prompt else None
    with open(output_file, 'w', encoding='utf8', newline='') as tsv_file:
        tsv_writer = csv.writer(tsv_file, delimiter='\t', lineterminator='\n')
        tsv_writer.writerow(HEADER)
        for ind, (row_prompt, synonyms) in enumerate(iter_prompts(syn_defs, syn_data, prompt), start=1):
            tsv_writer.writerow(score_row(row_prompt, synonyms))
            print(ind)  # progress: number of synonym sets written so far


if __name__ == "__main__":
    main()