-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
65 lines (60 loc) · 2.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tensorflow as tf
import model
import utils
def iter_sentences(lines):
    """Group per-character lines of a test file into sentences.

    Each non-empty line contributes only its first character; the rest of the
    line is ignored. An empty line ends the current sentence. Empty sentences
    (from consecutive blank lines or a leading blank line) are skipped rather
    than being sent to the model. A final sentence that is not followed by a
    blank line is still yielded.

    Args:
        lines: Iterable of text lines, with or without trailing newlines.

    Yields:
        list[str]: The characters of one sentence, in order.
    """
    sentence = []
    for line in lines:
        line = line.rstrip('\n')
        if line:
            sentence.append(line[0])
        elif sentence:
            yield sentence
            # Start a fresh list: the yielded one now belongs to the caller.
            sentence = []
    if sentence:
        yield sentence


def tag_sentence(seg_model, word2index_dict, label_trans_prob, chars):
    """Run the trained model on one sentence and return its label strings.

    Args:
        seg_model: The trained ``model.Model`` instance.
        word2index_dict: Lookup table from ``utils.get_word_data()``,
            forwarded to ``utils.build_test_data``.
        label_trans_prob: Matrix from ``utils.build_trans_prob``, forwarded
            to ``seg_model.eval``.
        chars: Characters of the sentence (as yielded by ``iter_sentences``).

    Returns:
        list: ``utils.index2label_dict[i]`` for each index ``eval`` returned.
    """
    test_data = utils.build_test_data(word2index_dict, chars)
    label_ids = seg_model.eval(test_data, label_trans_prob)
    return [utils.index2label_dict[i] for i in label_ids]


def main():
    """Train the model, then tag test files whose names are entered interactively.

    Training inputs come from ``utils``. After training, the user is prompted
    repeatedly for a test file name. Each file is tagged and the result is
    written to ``result.utf8`` (overwritten per file) as ``<char> <label>``
    lines, with a blank line after every sentence. Entering ``exit`` or
    closing stdin ends the loop.
    """
    print('reading sequences...')
    word_seqs, label_seqs = utils.read_seqs()
    print('loading word2vec model...')
    word_vec, word2index_dict = utils.get_word_data()
    print('buiding trans probality matrix...')
    label_trans_prob = utils.build_trans_prob(label_seqs)
    print('buiding train data...')
    train_data = utils.build_train_data(word2index_dict, word_seqs, label_seqs)

    # NOTE: no checkpoint is saved or restored here; the model is retrained
    # from scratch on every launch.
    sess = tf.Session()  # TF1-style session; closed in the finally below.
    try:
        # Named ``seg_model`` so the ``model`` module is not shadowed (inside a
        # function, rebinding ``model`` would make ``model.Model`` unbound).
        seg_model = model.Model(sess, word_vec, label_trans_prob)
        sess.run(tf.global_variables_initializer())
        print('start train model...')
        seg_model.train(train_data, 300)

        result_file_name = 'result.utf8'
        while True:
            try:
                test_file_name = input('enter test file name: \n')
            except EOFError:
                break
            if test_file_name == 'exit':
                break
            try:
                # Open the input first so a mistyped name does not truncate
                # the previous result file or abort the loop.
                test_file = open(test_file_name, 'r', encoding='utf8')
            except OSError as err:
                print('cannot open {!r}: {}'.format(test_file_name, err))
                continue
            with test_file, open(result_file_name, 'w', encoding='utf8') as result_file:
                for sen_num, chars in enumerate(iter_sentences(test_file), start=1):
                    labels = tag_sentence(seg_model, word2index_dict,
                                          label_trans_prob, chars)
                    result_file.writelines(word + ' ' + label + '\n'
                                           for word, label in zip(chars, labels))
                    result_file.write('\n')
                    print('eval No.' + str(sen_num) + ' sentence')
    finally:
        sess.close()


if __name__ == "__main__":
    main()