
I am unable to test the model #55

Open
bandarikanth opened this issue Dec 26, 2017 · 1 comment

Comments

@bandarikanth

I used this code:
```python
import tensorflow as tf
import numpy as np

# preprocessed data
from datasets.twitter import data
import data_utils

# load data from pickle and npy files
metadata, idx_q, idx_a = data.load_data(PATH='/home/kusuma/Videos/practical_seq2seq-master/datasets/twitter')
(trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(idx_q, idx_a)

# parameters
xseq_len = testX.shape[-1]
yseq_len = testY.shape[-1]
batch_size = 16
xvocab_size = len(metadata['idx2w'])
yvocab_size = xvocab_size
emb_dim = 1024

import seq2seq_wrapper

import importlib
importlib.reload(seq2seq_wrapper)

model = seq2seq_wrapper.Seq2Seq(xseq_len=xseq_len,
                                yseq_len=yseq_len,
                                xvocab_size=xvocab_size,
                                yvocab_size=yvocab_size,
                                ckpt_path='/home/kusuma/Videos/practical_seq2seq-master/ckpt/twitterseq2seq_model.ckpt-11000.data-00000-of-00001',
                                emb_dim=emb_dim,
                                num_layers=3
                                )

#val_batch_gen = data_utils.rand_batch_gen(validX, validY, 256)
test_batch_gen = data_utils.rand_batch_gen(testX, testY, 256)
#train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)

sess = model.test(test_batch_gen)
#sess = model.restore_last_session()

input_ = test_batch_gen.next()[0]
output = model.predict(sess, input_)
print(output.shape)
replies = []
for ii, oi in zip(input_.T, output):
    q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ')
    decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ')
    if decoded.count('unk') == 0:
        if decoded not in replies:
            print('q : [{0}]; a : [{1}]'.format(q, ' '.join(decoded)))
            replies.append(decoded)
```

I got this error:
```
usr/bin/python3.5 /home/kusuma/Videos/practical_seq2seq-master/test.py
Traceback (most recent call last):
  File "/home/kusuma/Videos/practical_seq2seq-master/test.py", line 42, in
Building Graph
    output = model.predict(sess, input_)
  File "/home/kusuma/Videos/practical_seq2seq-master/seq2seq_wrapper.py", line 175, in predict
    dec_op_v = sess.run(self.decode_outputs_test, feed_dict)
AttributeError: 'NoneType' object has no attribute 'run'
```
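A small guard makes this failure mode explicit instead of letting it surface as an `AttributeError` deep inside `seq2seq_wrapper.py`. This is a hypothetical sketch: `checked_predict` and `SessionNotLoadedError` are illustrative names, not part of the repo, and the only assumption about the model is that it exposes `predict(sess, X)` as in the code above.

```python
class SessionNotLoadedError(RuntimeError):
    """Raised when prediction is attempted before a session was restored."""


def checked_predict(model, sess, batch):
    """Call model.predict(sess, batch), failing early with a clear message.

    `model` is any object with a predict(sess, X) method, such as the
    Seq2Seq wrapper instance above (assumption: only that method is used).
    """
    if sess is None:
        # The state shown in the traceback: model.test(...) returned None.
        raise SessionNotLoadedError(
            "sess is None; restore a trained checkpoint into a session "
            "before calling predict()"
        )
    return model.predict(sess, batch)
```

With this in place, `output = checked_predict(model, sess, input_)` reports the missing session directly rather than `'NoneType' object has no attribute 'run'`.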

@PedroPei

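Oh, I see. You didn't load your model into `sess` before running it: `model.test(test_batch_gen)` returned `None`, so the `sess.run(...)` call inside `predict` fails.
You need to save your model with a `tf.train.Saver` after training, and restore it into a session before you test it.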
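Restoring also needs the right path. A TensorFlow 1.x checkpoint is a set of files sharing one prefix (typically `<prefix>.index`, `<prefix>.meta`, and `<prefix>.data-NNNNN-of-MMMMM` shards). `tf.train.Saver.restore` takes that prefix and `tf.train.get_checkpoint_state` takes the directory, whereas the code above passes a single `.data-00000-of-00001` shard as `ckpt_path`. The helper below is a hypothetical sketch (`split_checkpoint_path` is not a repo or TensorFlow function) that recovers both values from any one of those files.

```python
import os
import re

# Suffixes TensorFlow 1.x writes next to a checkpoint prefix:
# <prefix>.index, <prefix>.meta and <prefix>.data-NNNNN-of-MMMMM shards.
_CKPT_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def split_checkpoint_path(path):
    """Return (checkpoint_dir, checkpoint_prefix) for a file of a TF1 checkpoint.

    checkpoint_dir is what tf.train.get_checkpoint_state() expects;
    checkpoint_prefix is what tf.train.Saver.restore() expects.
    A path that already is a bare prefix is returned unchanged.
    """
    prefix = _CKPT_SUFFIX.sub("", path)
    return os.path.dirname(prefix), prefix


# Hypothetical TensorFlow 1.x usage (not executed here):
#   ckpt_dir, prefix = split_checkpoint_path(path)
#   saver = tf.train.Saver()      # after the model graph has been built
#   sess = tf.Session()
#   saver.restore(sess, prefix)   # sess now holds the trained weights
```

Which of the two values `Seq2Seq(ckpt_path=...)` expects depends on how `seq2seq_wrapper.py` restores checkpoints, so check its `restore_last_session` before choosing.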
