
Commit

using beam_search during inference
jingxil committed Jul 21, 2017
1 parent c04efd8 commit bc3c656
Showing 4 changed files with 109 additions and 36 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -19,6 +19,10 @@ training

$ python convex_hull.py --ARG=VALUE

evaluating

$ python convex_hull.py --forward_only=True --beam_width=VALUE --ARG=VALUE
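
For example, a hypothetical evaluation run (the flag values below are illustrative, not repository defaults):

$ python convex_hull.py --forward_only=True --beam_width=3 --batch_size=128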

visualizing

$ tensorboard --logdir=DIR
Binary file modified README/loss-5.png
14 changes: 12 additions & 2 deletions convex_hull.py
@@ -10,6 +10,7 @@
tf.app.flags.DEFINE_integer("rnn_size", 128, "RNN unit size.")
tf.app.flags.DEFINE_integer("attention_size", 128, "Attention size.")
tf.app.flags.DEFINE_integer("num_layers", 1, "Number of layers.")
tf.app.flags.DEFINE_integer("beam_width", 2, "Width of beam search .")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
tf.app.flags.DEFINE_float("max_gradient_norm", 5.0, "Maximum gradient norm.")
tf.app.flags.DEFINE_boolean("forward_only", False, "Forward Only.")
@@ -90,7 +91,8 @@ def build_model(self):
max_output_sequence_len=FLAGS.max_output_sequence_len,
rnn_size=FLAGS.rnn_size,
attention_size=FLAGS.attention_size,
num_layers=FLAGS.num_layers,
beam_width=FLAGS.beam_width,
learning_rate=FLAGS.learning_rate,
max_gradient_norm=FLAGS.max_gradient_norm,
forward_only=self.forward_only)
@@ -145,7 +147,15 @@ def train(self):
step_time, loss = 0.0, 0.0

def eval(self):
  """ Randomly get a batch of data and output predictions """
  inputs, enc_input_weights, outputs, dec_input_weights = self.get_batch()
  predicted_ids = self.model.step(self.sess, inputs, enc_input_weights)
  print("="*20)
  for i in range(FLAGS.batch_size):
    print("* %dth sample target: %s" % (i, str(outputs[i,1:]-2)))
    for predict in predicted_ids[i]:
      print("prediction: "+str(predict))
    print("="*20)

def run(self):
if self.forward_only:
127 changes: 93 additions & 34 deletions pointer_net.py
@@ -4,7 +4,34 @@
PAD_ID=1
END_ID=2

class PointerWrapper(tf.contrib.seq2seq.AttentionWrapper):
  """Customized AttentionWrapper for PointerNet."""

  def __init__(self, cell, attention_size, memory, initial_cell_state=None, name=None):
    # In the paper, the Bahdanau attention mechanism is used
    # We want the scores rather than the probabilities of alignments
    # Hence, we customize the probability_fn to return the scores directly
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attention_size, memory, probability_fn=lambda x: x)
    # According to the paper, there is no need to concatenate the input and the attention
    # Therefore, we make cell_input_fn return the input only
    cell_input_fn = lambda input, attention: input
    # Call super __init__
    super(PointerWrapper, self).__init__(cell,
                                         attention_mechanism=attention_mechanism,
                                         attention_layer_size=None,
                                         alignment_history=False,
                                         cell_input_fn=cell_input_fn,
                                         output_attention=True,
                                         initial_cell_state=initial_cell_state,
                                         name=name)

  @property
  def output_size(self):
    return self.state_size.alignments

  def call(self, inputs, state):
    _, next_state = super(PointerWrapper, self).call(inputs, state)
    return next_state.alignments, next_state
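
A minimal usage sketch of the wrapper, assuming TF 1.x with tf.contrib available (the shapes and the LSTM cell choice are illustrative):

# Encoder outputs serve as the memory: [batch, max_input_sequence_len+1, 2*rnn_size]
memory = tf.placeholder(tf.float32, [128, 6, 256])
pointer_cell = PointerWrapper(tf.contrib.rnn.LSTMCell(128), 128, memory)
# Each decoder step now emits one unnormalized score per encoder position
# (output_size equals the alignments size, 6 here), usable directly as pointer logits.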


class PointerNet(object):
""" Pointer Net Model
@@ -15,7 +42,9 @@ class PointerNet(object):
https://arxiv.org/abs/1506.03134.
"""

def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence_len=7,
             rnn_size=128, attention_size=128, num_layers=2, beam_width=2,
             learning_rate=0.001, max_gradient_norm=5, forward_only=False):
"""Create the model.
Args:
@@ -25,6 +54,7 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
rnn_size: the number of hidden units in each RNN layer
attention_size: the dimensionality of the attention mechanism
num_layers: the number of stacked RNN layers
beam_width: the width of the beam search
learning_rate: the initial learning rate during training
max_gradient_norm: gradients will be clipped to at most this norm
forward_only: whether the model runs forward (inference) only
@@ -68,15 +98,15 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
# Shape: batch_size*[max_output_sequence_len,1]
dec_input_ids = tf.unstack(tf.expand_dims(tf.stack(outputs_list[:-1],axis=1),2),axis=0)
# encoder input ids
# Shape: batch_size*[max_input_sequence_len+1,1]
enc_input_ids = [tf.expand_dims(tf.range(2,self.vocab_size),1)]*self.batch_size
# Look up encoder and decoder inputs
encoder_inputs = []
decoder_inputs = []
for i in range(self.batch_size):
encoder_inputs.append(tf.gather_nd(embedding_table_list[i], enc_input_ids[i]))
decoder_inputs.append(tf.gather_nd(embedding_table_list[i], dec_input_ids[i]))
# Shape: [batch_size,max_input_sequence_len+1,2]
encoder_inputs = tf.stack(encoder_inputs,axis=0)
# Shape: [batch_size,max_output_sequence_len,2]
decoder_inputs = tf.stack(decoder_inputs,axis=0)
@@ -86,30 +116,56 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
bw_enc_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers)])
else:
fw_enc_cell = cell(rnn_size)
bw_enc_cell = cell(rnn_size)
# Tile inputs if forward only
if self.forward_only:
# Tile encoder_inputs and enc_input_lens
encoder_inputs = tf.contrib.seq2seq.tile_batch(encoder_inputs,beam_width)
enc_input_lens = tf.contrib.seq2seq.tile_batch(enc_input_lens,beam_width)
# Encode input to obtain memory for later queries
memory,_ = tf.nn.bidirectional_dynamic_rnn(fw_enc_cell, bw_enc_cell, encoder_inputs, enc_input_lens, dtype=tf.float32)
# Shape: [batch_size(*beam_width), max_input_sequence_len+1, 2*rnn_size]
memory = tf.concat(memory, 2)
# PointerWrapper
pointer_cell = PointerWrapper(cell(rnn_size), attention_size, memory)
# Stack decoder cells if needed
if num_layers > 1:
dec_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers-1)]+[pointer_cell])
else:
dec_cell = pointer_cell
# Different decoding scenario
if self.forward_only:
# Tile embedding_table
tile_embedding_table = tf.tile(tf.expand_dims(embedding_table,1),[1,beam_width,1,1])
# Customize embedding_lookup_fn
def embedding_lookup(ids):
  # Note the output values of the decoder range only from 0 to max_input_sequence_len,
  # while embedding_table contains the values of two more tokens
  # To get around this, shift the ids
  # Shape: [batch_size,beam_width]
  ids = ids + 2
  # Shape: [batch_size,beam_width,vocab_size]
  one_hot_ids = tf.cast(tf.one_hot(ids,self.vocab_size), dtype=tf.float32)
  # Shape: [batch_size,beam_width,vocab_size,1]
  one_hot_ids = tf.expand_dims(one_hot_ids,-1)
  # Shape: [batch_size,beam_width,features_size]
  next_inputs = tf.reduce_sum(one_hot_ids*tile_embedding_table, axis=2)
  return next_inputs
# Do a little trick so that we can use 'BeamSearchDecoder'
shifted_START_ID = START_ID - 2
shifted_END_ID = END_ID - 2
# Beam Search Decoder
decoder = tf.contrib.seq2seq.BeamSearchDecoder(dec_cell, embedding_lookup,
tf.tile([shifted_START_ID],[self.batch_size]), shifted_END_ID,
dec_cell.zero_state(self.batch_size*beam_width,tf.float32), beam_width)
# Decode
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
# predicted_ids
# Shape: [batch_size, max_output_sequence_len, beam_width]
predicted_ids = outputs.predicted_ids
# Transpose predicted_ids
# Shape: [batch_size, beam_width, max_output_sequence_len]
self.predicted_ids = tf.transpose(predicted_ids,[0,2,1])
else:
# Get the maximum sequence length in current batch
cur_batch_max_len = tf.reduce_max(dec_input_lens)
@@ -118,15 +174,10 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
# Basic Decoder
decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, helper, dec_cell.zero_state(self.batch_size,tf.float32))
# Decode
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,impute_finished=True)
# Fetch logits, i.e. the attention scores over the input positions
# Shape: [batch_size,cur_batch_max_len,max_input_sequence_len+1]
logits = outputs.rnn_output
# predicted_ids_with_logits
self.predicted_ids_with_logits=tf.nn.top_k(logits)
# Pad logits to the same shape as targets
logits = tf.concat([logits,tf.ones([self.batch_size,self.max_output_sequence_len-cur_batch_max_len,self.max_input_sequence_len+1])],axis=1)
@@ -149,8 +200,6 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
optimizer = tf.train.AdamOptimizer(self.init_learning_rate)
# Update operator
self.update = optimizer.apply_gradients(zip(clipped_gradients, parameters),global_step=self.global_step)
# Summarize
tf.summary.scalar('loss',self.loss)
for p in parameters:
@@ -163,6 +212,9 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
self.debug_var = logits
#/DEBUG PART

# Saver
self.saver = tf.train.Saver(tf.global_variables())
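
A sketch of constructing the graph for beam-search inference, mirroring how build_model in convex_hull.py forwards the FLAGS (the argument values are illustrative):

model = PointerNet(batch_size=128, max_input_sequence_len=5, max_output_sequence_len=7,
                   rnn_size=128, attention_size=128, num_layers=1,
                   beam_width=3, forward_only=True)
# With forward_only=True the graph exposes model.predicted_ids
# instead of the training ops (update, loss, summaries).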

def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weights=None):
"""Run a step of the model feeding the given inputs.
@@ -174,8 +226,15 @@ def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weigh
dec_input_weights: the weights of decoder input points. shape: [batch_size,max_output_sequence_len]
Returns:
(training)
The summary
The total loss
The predicted ids with logits
The targets
The variable for debugging
(evaluation)
The predicted ids
"""
#Fill up inputs
input_feed = {}
Expand All @@ -187,7 +246,7 @@ def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weigh

#Fill up outputs
if self.forward_only:
output_feed = [self.predicted_ids]
else:
output_feed = [self.update, self.summary_op, self.loss, self.predicted_ids_with_logits, self.shifted_targets, self.debug_var]

@@ -197,6 +256,6 @@ def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weigh

#Return
if self.forward_only:
return outputs[0]
else:
return outputs[1],outputs[2],outputs[3],outputs[4],outputs[5]
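
A sketch of how a caller might drive step() in both modes, assuming a session and feeds prepared as in convex_hull.py (the variable names are illustrative):

# Training mode (forward_only=False): five values come back
summary, loss, predicted_ids_with_logits, targets, debug_var = model.step(
    sess, inputs, enc_input_weights, outputs, dec_input_weights)

# Inference mode (forward_only=True): beam-searched ids only
# Shape: [batch_size, beam_width, max_output_sequence_len]
predicted_ids = model.step(sess, inputs, enc_input_weights)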
