diff --git a/README.md b/README.md
index ebff558..97879ba 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,10 @@ training
 
 $ python convex_hull.py --ARG=VALUE
 
+evaluating
+
+$ python convex_hull.py --forward_only=True --beam_width=VALUE --ARG=VALUE
+
 visualizing
 
 $ tensorboard --logdir=DIR
diff --git a/README/loss-5.png b/README/loss-5.png
index b6d506a..67ee92e 100644
Binary files a/README/loss-5.png and b/README/loss-5.png differ
diff --git a/convex_hull.py b/convex_hull.py
index 8a919b2..3156e73 100644
--- a/convex_hull.py
+++ b/convex_hull.py
@@ -10,6 +10,7 @@
 tf.app.flags.DEFINE_integer("rnn_size", 128, "RNN unit size.")
 tf.app.flags.DEFINE_integer("attention_size", 128, "Attention size.")
 tf.app.flags.DEFINE_integer("num_layers", 1, "Number of layers.")
+tf.app.flags.DEFINE_integer("beam_width", 2, "Width of beam search.")
 tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
 tf.app.flags.DEFINE_float("max_gradient_norm", 5.0, "Maximum gradient norm.")
 tf.app.flags.DEFINE_boolean("forward_only", False, "Forward Only.")
@@ -90,7 +91,8 @@ def build_model(self):
         max_output_sequence_len=FLAGS.max_output_sequence_len,
         rnn_size=FLAGS.rnn_size,
         attention_size=FLAGS.attention_size,
-        num_layers=FLAGS.num_layers,
+        num_layers=FLAGS.num_layers,
+        beam_width=FLAGS.beam_width,
         learning_rate=FLAGS.learning_rate,
         max_gradient_norm=FLAGS.max_gradient_norm,
         forward_only=self.forward_only)
@@ -145,7 +147,15 @@ def train(self):
         step_time, loss = 0.0, 0.0
 
   def eval(self):
-    pass
+    """Randomly get a batch of data and output predictions."""
+    inputs, enc_input_weights, outputs, dec_input_weights = self.get_batch()
+    predicted_ids = self.model.step(self.sess, inputs, enc_input_weights)
+    print("="*20)
+    for i in range(FLAGS.batch_size):
+      print("* %dth sample target: %s" % (i, str(outputs[i,1:]-2)))
+      for predict in predicted_ids[i]:
+        print("prediction: "+str(predict))
+    print("="*20)
 
   def run(self):
     if self.forward_only:
diff --git a/pointer_net.py b/pointer_net.py
index 045105c..5021d13 100644
--- a/pointer_net.py
+++ b/pointer_net.py
@@ -4,7 +4,34 @@
 PAD_ID=1
 END_ID=2
 
+class PointerWrapper(tf.contrib.seq2seq.AttentionWrapper):
+  """Customized AttentionWrapper for PointerNet."""
+  def __init__(self, cell, attention_size, memory, initial_cell_state=None, name=None):
+    # In the paper, the Bahdanau attention mechanism is used
+    # We want the scores rather than the probabilities of alignments
+    # Hence, we customize the probability_fn to return scores directly
+    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attention_size, memory, probability_fn=lambda x: x)
+    # According to the paper, there is no need to concatenate the input and attention
+    # Therefore, we make cell_input_fn return the input only
+    cell_input_fn = lambda input, attention: input
+    # Call super __init__
+    super(PointerWrapper, self).__init__(cell,
+                                         attention_mechanism=attention_mechanism,
+                                         attention_layer_size=None,
+                                         alignment_history=False,
+                                         cell_input_fn=cell_input_fn,
+                                         output_attention=True,
+                                         initial_cell_state=initial_cell_state,
+                                         name=name)
+  @property
+  def output_size(self):
+    return self.state_size.alignments
+
+  def call(self, inputs, state):
+    _, next_state = super(PointerWrapper, self).call(inputs, state)
+    return next_state.alignments, next_state
+
 class PointerNet(object):
   """ Pointer Net Model
 
@@ -15,7 +42,9 @@
     https://arxiv.org/abs/1506.03134.
""" - def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence_len=7, rnn_size=128, attention_size=128, num_layers=2, learning_rate=0.001, max_gradient_norm=5, forward_only=False): + def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence_len=7, + rnn_size=128, attention_size=128, num_layers=2, beam_width=2, + learning_rate=0.001, max_gradient_norm=5, forward_only=False): """Create the model. Args: @@ -25,6 +54,7 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence rnn_size: the size of each RNN hidden units attention_size: the size of dimensions in attention mechanism num_layers: the number of stacked RNN layers + beam_width: the width of beam search learning_rate: the initial learning rate during training max_gradient_norm: gradients will be clipped to maximally this norm. forward_only: whether the model is forwarding only @@ -68,7 +98,7 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence # Shape: batch_size*[max_output_sequence_len,1] dec_input_ids = tf.unstack(tf.expand_dims(tf.stack(outputs_list[:-1],axis=1),2),axis=0) # encoder input ids - # Shape: batch_size*[vocab_size,1] + # Shape: batch_size*[max_input_sequence_len+1,1] enc_input_ids = [tf.expand_dims(tf.range(2,self.vocab_size),1)]*self.batch_size # Look up encoder and decoder inputs encoder_inputs = [] @@ -76,7 +106,7 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence for i in range(self.batch_size): encoder_inputs.append(tf.gather_nd(embedding_table_list[i], enc_input_ids[i])) decoder_inputs.append(tf.gather_nd(embedding_table_list[i], dec_input_ids[i])) - # Shape: [batch_size,vocab_size,2] + # Shape: [batch_size,max_input_sequence_len+1,2] encoder_inputs = tf.stack(encoder_inputs,axis=0) # Shape: [batch_size,max_output_sequence_len,2] decoder_inputs = tf.stack(decoder_inputs,axis=0) @@ -86,30 +116,56 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence bw_enc_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers)]) else: fw_enc_cell = cell(rnn_size) - bw_enc_cell = cell(rnn_size) - # Encode input to obtain the inital state for decoder - #_,enc_states = tf.nn.dynamic_rnn(enc_cell,encoder_inputs,enc_input_lens,dtype=tf.float32) + bw_enc_cell = cell(rnn_size) + # Tile inputs if forward only + if self.forward_only: + # Tile encoder_inputs and enc_input_lens + encoder_inputs = tf.contrib.seq2seq.tile_batch(encoder_inputs,beam_width) + enc_input_lens = tf.contrib.seq2seq.tile_batch(enc_input_lens,beam_width) # Encode input to obtain memory for later queries memory,_ = tf.nn.bidirectional_dynamic_rnn(fw_enc_cell, bw_enc_cell, encoder_inputs, enc_input_lens, dtype=tf.float32) - # Shape: [batch_size, max_input_sequence_len+1, 2*rnn_size] + # Shape: [batch_size(*beam_width), max_input_sequence_len+1, 2*rnn_size] memory = tf.concat(memory, 2) - # Choose Attention Mechanism - # We want the scores rather than the probabilities of alignments - # Hence, we customize the probability_fn to return scores directly - attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attention_size, memory, probability_fn=lambda x: x ) - # Build attention cell - # According to the paper, no need to concatenate the input and attention - # Therefore, we make cell_input_fn to return input only - attn_cell = tf.contrib.seq2seq.AttentionWrapper(cell(rnn_size), attention_mechanism, alignment_history=True, - cell_input_fn=lambda input, attention: input) + # 
+    # PointerWrapper
+    pointer_cell = PointerWrapper(cell(rnn_size), attention_size, memory)
     # Stack decoder cells if needed
     if num_layers > 1:
-      dec_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers-1)]+[attn_cell])
+      dec_cell = tf.contrib.rnn.MultiRNNCell([cell(rnn_size) for _ in range(num_layers-1)]+[pointer_cell])
     else:
-      dec_cell = attn_cell
-
+      dec_cell = pointer_cell
+    # Different decoding scenario
     if self.forward_only:
-      raise NotImplementedError()
+      # Tile embedding_table
+      tile_embedding_table = tf.tile(tf.expand_dims(embedding_table,1), [1,beam_width,1,1])
+      # Customize embedding_lookup_fn
+      def embedding_lookup(ids):
+        # Note the output value of the decoder only ranges from 0 to max_input_sequence_len
+        # while embedding_table contains the values of two more tokens
+        # To get around this, shift the ids
+        # Shape: [batch_size,beam_width]
+        ids = ids + 2
+        # Shape: [batch_size,beam_width,vocab_size]
+        one_hot_ids = tf.cast(tf.one_hot(ids, self.vocab_size), dtype=tf.float32)
+        # Shape: [batch_size,beam_width,vocab_size,1]
+        one_hot_ids = tf.expand_dims(one_hot_ids, -1)
+        # Shape: [batch_size,beam_width,features_size]
+        next_inputs = tf.reduce_sum(one_hot_ids*tile_embedding_table, axis=2)
+        return next_inputs
+      # Do a little trick so that we can use 'BeamSearchDecoder'
+      shifted_START_ID = START_ID - 2
+      shifted_END_ID = END_ID - 2
+      # Beam Search Decoder
+      decoder = tf.contrib.seq2seq.BeamSearchDecoder(dec_cell, embedding_lookup,
+                    tf.tile([shifted_START_ID],[self.batch_size]), shifted_END_ID,
+                    dec_cell.zero_state(self.batch_size*beam_width, tf.float32), beam_width)
+      # Decode
+      outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
+      # predicted_ids
+      # Shape: [batch_size, max_output_sequence_len, beam_width]
+      predicted_ids = outputs.predicted_ids
+      # Transpose predicted_ids
+      # Shape: [batch_size, beam_width, max_output_sequence_len]
+      self.predicted_ids = tf.transpose(predicted_ids, [0,2,1])
     else:
       # Get the maximum sequence length in current batch
       cur_batch_max_len = tf.reduce_max(dec_input_lens)
@@ -118,15 +174,10 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
       # Basic Decoder
       decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, helper, dec_cell.zero_state(self.batch_size,tf.float32))
       # Decode
-      _, states, _ = tf.contrib.seq2seq.dynamic_decode(decoder,impute_finished=True)
-      # Locate the cell state containing logits
-      if num_layers>1:
-        state=states[-1]
-      else:
-        state=states
-      # Fetch logits
-      # Shape: [batch_size,cur_batch_max_len,max_input_sequence_len+1]
-      logits = tf.transpose(state.alignment_history.stack(), [1,0,2])
+      outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,impute_finished=True)
+      # logits
+      logits = outputs.rnn_output
+      # predicted_ids_with_logits
       self.predicted_ids_with_logits=tf.nn.top_k(logits)
       # Pad logits to the same shape as targets
       logits = tf.concat([logits,tf.ones([self.batch_size,self.max_output_sequence_len-cur_batch_max_len,self.max_input_sequence_len+1])],axis=1)
@@ -149,8 +200,6 @@ def __init__(self, batch_size=128, max_input_sequence_len=5, max_output_sequence
     optimizer = tf.train.AdamOptimizer(self.init_learning_rate)
     # Update operator
     self.update = optimizer.apply_gradients(zip(clipped_gradients, parameters),global_step=self.global_step)
-    # Saver
-    self.saver = tf.train.Saver(tf.global_variables())
     # Summarize
     tf.summary.scalar('loss',self.loss)
     for p in parameters:
@@ -163,6 +212,9 @@
     self.debug_var = logits
     #/DEBUG PART
 
+    # Saver
+    self.saver = tf.train.Saver(tf.global_variables())
+
   def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weights=None):
     """Run a step of the model feeding the given inputs.
 
@@ -174,8 +226,15 @@ def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weigh
       dec_input_weights: the weights of decoder input points.
         shape: [batch_size,max_output_sequence_len]
 
    Returns:
-      A triple
+      (training)
+      The summary
+      The total loss
+      The predicted ids with logits
+      The targets
+      The variable for debugging
+      (evaluation)
+      The predicted ids
     """
     #Fill up inputs
     input_feed = {}
@@ -187,7 +246,7 @@ def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weigh
 
     #Fill up outputs
     if self.forward_only:
-      raise NotImplementedError()
+      output_feed = [self.predicted_ids]
     else:
       output_feed = [self.update, self.summary_op, self.loss, self.predicted_ids_with_logits,
                      self.shifted_targets, self.debug_var]
@@ -197,6 +256,6 @@ def step(self, session, inputs, enc_input_weights, outputs=None, dec_input_weigh
 
     #Return
     if self.forward_only:
-      pass
+      return outputs[0]
     else:
       return outputs[1],outputs[2],outputs[3],outputs[4],outputs[5]
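A note on the identity probability_fn that PointerWrapper passes to BahdanauAttention: the mechanism normally softmax-normalizes its attention scores, while a pointer network needs the raw scores as logits over input positions (the training branch feeds them straight into a softmax cross-entropy loss). The following is a minimal sketch under the TF 1.x contrib APIs, not part of the patch; the placeholder names and shapes are illustrative only.

import tensorflow as tf  # TF 1.x

# Hypothetical encoder memory: [batch, max_input_sequence_len+1, 2*rnn_size]
memory = tf.placeholder(tf.float32, [None, 6, 256])
attention = tf.contrib.seq2seq.BahdanauAttention(
    num_units=128,
    memory=memory,
    probability_fn=lambda score: score)  # identity: keep raw scores, skip softmax

With the softmax skipped, the wrapper's call() can return next_state.alignments directly, so each decoder step emits one logit per input position.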
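The tile_batch calls in the forward-only branch are needed because BeamSearchDecoder tracks beam_width hypotheses per example and expects the attention memory and sequence lengths repeated along the batch axis. A minimal sketch, again assuming TF 1.x; names and shapes are illustrative:

import tensorflow as tf  # TF 1.x

beam_width = 2
memory = tf.placeholder(tf.float32, [None, 6, 256])  # encoder outputs
lengths = tf.placeholder(tf.int32, [None])           # input lengths
# tile_batch repeats each batch entry beam_width times: [a, b] -> [a, a, b, b]
tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(lengths, multiplier=beam_width)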
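The customized embedding_lookup exists because every example carries its own table of point coordinates, so a single shared embedding matrix cannot be used; the patch instead one-hot encodes the (shifted) ids and sums them against a beam-tiled copy of the per-example table. A self-contained sketch of the same computation with made-up sizes (table, ids, and all dimensions are illustrative, not the patch's variables):

import tensorflow as tf  # TF 1.x

batch_size, beam_width, vocab_size, features = 4, 2, 7, 2
# Per-example coordinate table, tiled across beams:
# [batch, vocab, features] -> [batch, beam, vocab, features]
table = tf.random_normal([batch_size, vocab_size, features])
tiled_table = tf.tile(tf.expand_dims(table, 1), [1, beam_width, 1, 1])

def embedding_lookup(ids):
    # ids: [batch, beam]; shift by 2 so decoder outputs skip the two
    # special-token rows, mirroring shifted_START_ID / shifted_END_ID
    one_hot = tf.one_hot(ids + 2, vocab_size, dtype=tf.float32)
    # [batch, beam, vocab, 1] * [batch, beam, vocab, features], summed over vocab
    return tf.reduce_sum(tf.expand_dims(one_hot, -1) * tiled_table, axis=2)

ids = tf.zeros([batch_size, beam_width], dtype=tf.int32)
next_inputs = embedding_lookup(ids)  # [batch, beam, features]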