forked from suriyadeepan/practical_seq2seq
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path03-Twitter-chatbot.py
45 lines (32 loc) · 1.14 KB
/
03-Twitter-chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# In[1]:
import tensorflow as tf
import numpy as np
# preprocessed data
from datasets.twitter import data
import data_utils
# Load the preprocessed Twitter corpus from its pickle/npy files: `metadata`
# plus two index arrays, idx_q and idx_a (presumably question/answer token
# ids -- confirm in datasets.twitter.data).
metadata, idx_q, idx_a = data.load_data(PATH='datasets/twitter/')
# Split the paired arrays into train, test and validation (X, Y) pairs.
(trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(
    idx_q, idx_a
)

# Fixed training hyperparameters.
batch_size = 32
emb_dim = 1024

# Encoder/decoder sequence lengths are the sizes of the last axis of the
# training arrays.
xseq_len, yseq_len = trainX.shape[-1], trainY.shape[-1]

# Input and output share one vocabulary, sized from metadata['idx2w']
# (presumably the index-to-word table -- TODO confirm).
xvocab_size = yvocab_size = len(metadata['idx2w'])
import seq2seq_wrapper
# In[7]:
# Build the seq2seq model from the dimensions computed above. The encoder and
# decoder sizes come from the data; num_layers and ckpt_path are fixed here
# (ckpt_path is presumably where checkpoints are saved/restored -- see
# seq2seq_wrapper).
model = seq2seq_wrapper.Seq2Seq(
    xseq_len=xseq_len,
    xvocab_size=xvocab_size,
    yseq_len=yseq_len,
    yvocab_size=yvocab_size,
    emb_dim=emb_dim,
    num_layers=3,
    ckpt_path='ckpt/twitter/',
)
# In[8]:
# Random-batch generators over the validation and training splits. Both use
# the shared `batch_size` so changing that one setting keeps them in sync.
val_batch_gen = data_utils.rand_batch_gen(validX, validY, batch_size)
train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)
# In[9]:
# NOTE(review): the session returned by restore_last_session() is overwritten
# by the train() call below and never read, so restored weights do not
# visibly reach train(). TODO confirm whether train() accepts an existing
# session (e.g. a `sess=` argument) and pass this one through if so -- taking
# care that, when no checkpoint exists, an un-initialised session is not
# handed to train().
sess = model.restore_last_session()
sess = model.train(train_batch_gen, val_batch_gen)