# Eval_Trigger.py
import nltk
import tensorflow as tf

from Config import HyperParams_Tri_classification as hp
from Dataset_Trigger import Dataset_Trigger as TRIGGER_DATASET
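# Note: nltk.word_tokenize depends on NLTK's 'punkt' tokenizer data; if it is
# not installed, run nltk.download('punkt') once beforehand.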
def get_batch(sentence, word_id, max_sequence_length):
    """Build an evaluation batch for a single sentence.

    Every row of the batch holds the same token-id sequence; row i carries
    position features relative to token i, so each token is scored once as
    a candidate trigger.
    """
    tokens = nltk.word_tokenize(sentence)
    # Pad (or truncate) the token list to max_sequence_length with '<eos>'.
    words = []
    for i in range(max_sequence_length):
        if i < len(tokens):
            words.append(tokens[i])
        else:
            words.append('<eos>')
    # Map words to vocabulary ids, falling back to '<unk>' for OOV words.
    word_ids = []
    for word in words:
        if word in word_id:
            word_ids.append(word_id[word])
        else:
            word_ids.append(word_id['<unk>'])
    size = len(word_ids)
    x_batch = []
    x_pos_batch = []
    for i in range(size):
        x_batch.append(word_ids)
        # Relative offset of every position j to the candidate position i.
        x_pos_batch.append([j - i for j in range(size)])
    return x_batch, x_pos_batch, tokens
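# Illustrative sketch (hypothetical values, not from this repo): with
# max_sequence_length = 4 and the sentence "He left quickly", get_batch
# returns one row per candidate trigger position:
#   x_batch     = [[id_He, id_left, id_quickly, id_eos]] * 4
#   x_pos_batch = [[ 0,  1,  2,  3],
#                  [-1,  0,  1,  2],
#                  [-2, -1,  0,  1],
#                  [-3, -2, -1,  0]]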
if __name__ == '__main__':
    dataset = TRIGGER_DATASET(batch_size=hp.batch_size, max_sequence_length=hp.max_sequence_length,
                              windows=hp.windows, dtype='IDENTIFICATION')
    x_batch, x_pos_batch, tokens = get_batch(sentence='It could swell to as much as $500 billion if we go to war in Iraq',
                                             word_id=dataset.word_id, max_sequence_length=hp.max_sequence_length)
    print('x_batch :', x_batch)
    print('x_pos_batch :', x_pos_batch)

    # Locate the most recent checkpoint from this training run.
    checkpoint_dir = './runs/1542831140/checkpoints'
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session()
        with sess.as_default():
            # Load the saved meta graph and restore variables.
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name.
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_c_pos = graph.get_operation_by_name("input_c_pos").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

            # Tensor we want to evaluate.
            predictions = graph.get_operation_by_name("output/predicts").outputs[0]

            feed_dict = {
                input_x: x_batch,
                input_c_pos: x_pos_batch,
                dropout_keep_prob: 1.0,  # disable dropout at evaluation time
            }
            preds = sess.run(predictions, feed_dict)

            print('result!')
            # preds[i] is the prediction for row i, i.e. token i as candidate.
            for i in range(len(preds)):
                print('{}: {}'.format(dataset.id2word[x_batch[0][i]], preds[i]))
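            # Hedged alternative (a sketch, not part of the original script):
            # report predictions against the original tokens and skip <eos>
            # padding, since id2word maps padded positions back to '<eos>'.
            for i, token in enumerate(tokens[:len(preds)]):
                print('{}: {}'.format(token, preds[i]))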