# test.py -- forked from thunlp/OpenKE
import config
import models
import tensorflow as tf
import numpy as np
import json
import os
# Process-wide environment overrides applied when this module is loaded:
# limit CUDA to device index 0 and set CONDA_PREFIX to the empty string.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'CONDA_PREFIX': '',
})
def main():
    """Run the configured evaluation and rename the resulting log file.

    Builds a ``config.Config`` for the hard-coded dataset / model / cluster
    selection below, switches on both the link-prediction and the
    triple-classification test flags, points the config at the saved model
    file, calls ``test`` and finally renames ``log.log`` (presumably written
    by ``test`` -- confirm) to a name that identifies the run.
    """
    dataset = 'RV15M'  # FB15K
    model_name = "TransE"  # DistMult
    cluster_tag = "V2"

    # Every path used by this run is derived from the three selections above.
    data_dir = "/u/wujieche/Projects/OpenKE/data/"+dataset+"/"
    weights_file = "models/{}-{}-{}_model.vec.tf".format(model_name, dataset, cluster_tag)
    renamed_log = "{}-{}-{}_clusters_test_log.log".format(model_name, dataset, cluster_tag)

    cfg = config.Config()
    cfg.set_in_path(data_dir)
    cfg.set_test_link_prediction(True)
    cfg.set_test_triple_classification(True)
    cfg.set_work_threads(8)
    cfg.set_dimension(100)
    cfg.set_import_files(weights_file)
    cfg.init()
    # Resolve the model class by name on the ``models`` package.
    cfg.set_model(getattr(models, model_name))  # models.TransE
    cfg.test(cluster_tag)

    os.rename("log.log", renamed_log)
    # rank_by_relID_plot()
def read_log(path):
    """Parse a whitespace-separated test log into per-line rank records.

    Every non-blank line must consist of exactly five whitespace-separated
    fields, in this order: ``arg1 rel arg2 head_rank tail_rank``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of aborting the whole read.

    Args:
        path: Path of the log file to read.

    Returns:
        A list with one ``[rel, head_rank, tail_rank]`` entry per non-blank
        line, in file order. The values are kept as the strings found in the
        file; no numeric conversion is done here.

    Raises:
        ValueError: If a non-blank line does not have exactly five fields.
    """
    records = []
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Nothing to unpack on an empty line; skip it.
                continue
            # arg1 and arg2 are present in the log but not needed by callers.
            _arg1, rel, _arg2, head_rank, tail_rank = fields
            records.append([rel, head_rank, tail_rank])
    return records
def rank_by_relID_plot(log_path):
    """Plot head and tail ranks against relation ID and save the figure.

    Reads the log with ``read_log``, draws the tail ranks and the head ranks
    as two lines over the relation IDs, adds two horizontal baseline lines
    and writes the figure next to the log.

    Args:
        log_path: Path of the log file. Its last four characters are dropped
            (assumes a ``.log``-style extension) and ``.MR_by_ID-graph.png``
            is appended to form the output file name.

    Raises:
        ValueError: If a relation ID or rank in the log is not numeric.
    """
    import matplotlib.pyplot as plt

    log = read_log(log_path)

    # read_log returns every field as a string. Handing strings to matplotlib
    # selects categorical axes, which clashes with the numeric baselines and
    # the numeric x-range below, so convert to numbers before plotting.
    ids = [float(r) for r, _, __ in log]
    head_ranks = [float(hr) for _, hr, __ in log]
    tail_ranks = [float(tr) for _, __, tr in log]

    # FIXME Make a pandas dataframe directly out of the log data
    # (because pandas has nice data processing functions useful for graphing)
    # df = pd.read_csv("log_no_cluster.csv")
    # df = df.drop(columns=["id", "index", "h", "t"])
    # max_r = max(df['r'])
    # df['ranked_ids'] = df['r'].rank(method='first')
    # df['x'] = pd.qcut(df['ranked_ids'], 1000)
    # bins = 30
    # grp = df.groupby(by = pd.qcut(df['r'], bins))
    # df = grp.aggregate(np.average)

    # Hard-coded baseline values; presumably the mean ranks of a
    # frequency-based baseline -- TODO confirm where they come from.
    freq_head_MR, freq_tail_MR = 169512, 237961
    darkblue, darkgreen = "#3030AA", "#40AA40"

    # NOTE(review): the tail ranks carry the "arg1" label and the head ranks
    # the "arg2" label; kept exactly as found -- confirm this is intended.
    plt.plot(ids, tail_ranks, "b", label="arg1 mean rank")
    plt.plot(ids, head_ranks, "g", label="arg2 mean rank")
    plt.hlines([freq_head_MR, freq_tail_MR], xmin=0, xmax=10000,
               colors=[darkblue, darkgreen], label="baseline")
    plt.legend()
    plt.savefig(log_path[:-4]+".MR_by_ID-graph.png")
# Start the evaluation only when this file is executed as a script, so that
# importing the module (e.g. to reuse read_log or rank_by_relID_plot) does
# not kick off a full test run.
if __name__ == "__main__":
    main()