-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: excute.py
153 lines (122 loc) · 4.49 KB
/
excute.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import pickle
import numpy as np
import sys
from vocabulary import PAD_token, SOS_token, EOS_token, Voc
from model import EncoderRNN, LuongAttnDecoderRNN, GreedySearchDecoder, trainIters, evaluateInput
# Device setup: run on GPU when CUDA is available, otherwise fall back to CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')
MAX_LENGTH = 15 # Maximum sentence length to consider
#
# # load voc and pairs
def loadDataset(data_dir='data'):
    """Load the pickled vocabulary and sentence pairs.

    Args:
        data_dir: directory containing ``voc.pkl`` and ``pairs.pkl``.
            Defaults to ``'data'``, preserving the original behavior.

    Returns:
        tuple: ``(voc, pairs)`` — the unpickled vocabulary object and the
        unpickled list of sentence pairs.
    """
    # NOTE: pickle.load executes arbitrary code from the file; these files
    # are assumed to be locally produced by the preprocessing step.
    with open(os.path.join(data_dir, 'voc.pkl'), 'rb') as handle_voc:
        voc = pickle.load(handle_voc)
    with open(os.path.join(data_dir, 'pairs.pkl'), 'rb') as handle_pairs:
        pairs = pickle.load(handle_pairs)
    return voc, pairs
voc, pairs = loadDataset()
# Paths used when saving/loading checkpoints.
save_dir = os.path.join('data', 'save')
corpus_name = 'cornell movie-dialogs corpus'
corpus = os.path.join('data', corpus_name)
# Model configuration
model_name = 'cb_model'
# attn_model = 'dot'
attn_model = 'general'
# attn_model = 'concat'
hidden_size = 256
encoder_n_layers = 3
decoder_n_layers = 3
dropout = 0.1
batch_size = 64
# Training hyperparameters and optimizer configuration
clip = 5.0  # gradient-clipping threshold passed to trainIters
teacher_forcing_ratio = 0.5
learning_rate = 0.0001
decoder_learning_ratio = 5.0  # decoder LR = learning_rate * this ratio (see optimizer below)
n_iteration = 4000
checkpoint_iter = 4000 # which checkpoint iteration to restore in loadCheckpoint()
print_every = 10
save_every = 1000
# Initially None: train from scratch.
# NOTE(review): resuming training from a checkpoint is not wired in here.
loadFilename = None
print('Building encoder and decoder ...')
# Initializing word embedding
embedding = nn.Embedding(voc.num_words, hidden_size)
# initializing encoder and decoder
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
# use device
encoder = encoder.to(device)
decoder = decoder.to(device)
# Put the modules in training mode so that dropout is active.
encoder.train()
decoder.train()
# Initialize the optimizers (decoder gets a larger learning rate).
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
# Restore model/optimizer state from a saved training checkpoint.
def loadCheckpoint():
    """Load the checkpoint for ``checkpoint_iter`` and restore state in place.

    Reads ``{checkpoint_iter}_checkpoint.tar`` from the standard save
    directory and restores the embedding, encoder, decoder, both optimizer
    states, and the vocabulary dictionary.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    checkpoint_path = os.path.join(
        save_dir, model_name, corpus_name,
        '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
        '{}_checkpoint.tar'.format(checkpoint_iter))
    # Fail loudly with a clear message instead of an opaque torch.load error.
    # (The original's `if loadFilename:` check was always true — the local
    # assignment shadowed the module-level loadFilename.)
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError('Checkpoint not found: {}'.format(checkpoint_path))
    # map_location=device handles every save/run device combination: a
    # GPU-saved checkpoint loads correctly on a CPU-only machine and vice
    # versa (the original crashed in that case and carried a commented-out
    # CPU-only workaround).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Restore each component's saved state in place.
    embedding.load_state_dict(checkpoint['embedding'])
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])
    encoder_optimizer.load_state_dict(checkpoint['en_opt'])
    decoder_optimizer.load_state_dict(checkpoint['de_opt'])
    voc.__dict__ = checkpoint['voc_dict']
#
# # training
def train():
    """Run the full training loop.

    Thin wrapper that forwards the module-level model, optimizers, data,
    and hyperparameters to ``trainIters`` (defined in the ``model`` module).
    Argument order must match trainIters' signature — do not reorder.
    """
    trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
               embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
               print_every, save_every, clip, corpus_name, loadFilename, hidden_size, teacher_forcing_ratio)
#
# # testing
def test():
    """Restore the trained model from its checkpoint and chat interactively."""
    # Pull the saved weights back into the module-level encoder/decoder.
    loadCheckpoint()
    # Switch to evaluation mode so dropout is disabled during inference.
    encoder.eval()
    decoder.eval()
    # Build the search strategy used by evaluateInput for decoding replies.
    greedy_searcher = GreedySearchDecoder(encoder, decoder)
    evaluateInput(encoder, decoder, greedy_searcher, voc)
# Entry point: dispatch to train() or test() based on the CLI argument.
def main(argv):
    """Dispatch to training or interactive testing.

    Args:
        argv: ``sys.argv``-style list; ``argv[1]`` must be ``'train'``
            or ``'test'``.

    Prints a usage message and returns (instead of crashing with
    IndexError, or silently doing nothing on an unknown command) when the
    argument is missing or unrecognized.
    """
    if len(argv) < 2 or argv[1] not in ('train', 'test'):
        prog = argv[0] if argv else 'excute.py'
        print('usage: {} [train|test]'.format(prog))
        return
    if argv[1] == 'train':
        train()
    else:
        test()


if __name__ == "__main__":
    main(sys.argv)