
Commit

Merge pull request #118 from lukaszkaiser/push
1.0.12
lukaszkaiser authored Jul 8, 2017
2 parents fbb6f9a + 54622a5 commit 83ac9fb
Showing 12 changed files with 113 additions and 97 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -72,7 +72,7 @@ t2t-datagen \
--num_shards=100 \
--problem=$PROBLEM
mv $TMP_DIR/tokens.vocab.32768 $DATA_DIR
cp $TMP_DIR/tokens.vocab.* $DATA_DIR
# Train
# * If you run out of memory, add --hparams='batch_size=2048' or even 1024.
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.11',
version='1.0.12',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
105 changes: 42 additions & 63 deletions tensor2tensor/bin/t2t-datagen
@@ -90,25 +90,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
"algorithmic_reverse_nlplike_decimal8K": (
lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000,
10, 1.300),
lambda: algorithmic.reverse_generator_nlplike(8000, 700, 10000,
lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000,
10, 1.300)),
"algorithmic_reverse_nlplike_decimal32K": (
lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000,
10, 1.050),
lambda: algorithmic.reverse_generator_nlplike(32000, 700, 10000,
lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000,
10, 1.050)),
"algorithmic_algebra_inverse": (
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
"algorithmic_algebra_simplify": (
lambda: algorithmic_math.algebra_simplify(8, 0, 2, 100000),
lambda: algorithmic_math.algebra_simplify(8, 3, 3, 10000)),
"algorithmic_calculus_integrate": (
lambda: algorithmic_math.calculus_integrate(8, 0, 2, 100000),
lambda: algorithmic_math.calculus_integrate(8, 3, 3, 10000)),
"wmt_parsing_characters": (
lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, True),
lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, False)),
"wmt_parsing_tokens_8k": (
lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
@@ -133,10 +124,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
),
"wmt_enfr_tokens_128k": (
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
),
"wmt_ende_characters": (
lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),
@@ -151,10 +138,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
),
"wmt_ende_tokens_128k": (
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
),
"image_mnist_tune": (
lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
@@ -227,33 +210,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
40000,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15)),
"image_mscoco_tokens_128k_tune": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
70000,
vocab_filename="tokens.vocab.%d" % 2**17,
vocab_size=2**17),
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
10000,
70000,
vocab_filename="tokens.vocab.%d" % 2**17,
vocab_size=2**17)),
"image_mscoco_tokens_128k_test": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
80000,
vocab_filename="tokens.vocab.%d" % 2**17,
vocab_size=2**17),
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
False,
40000,
vocab_filename="tokens.vocab.%d" % 2**17,
vocab_size=2**17)),
"snli_32k": (
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
@@ -340,10 +296,31 @@ def set_random_seed():

def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.problem not in _SUPPORTED_PROBLEM_GENERATORS:

# Calculate the list of problems to generate.
problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
if FLAGS.problem and FLAGS.problem[-1] == "*":
problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
elif FLAGS.problem:
problems = [p for p in problems if p == FLAGS.problem]
else:
problems = []
# Remove TIMIT if paths are not given.
if not FLAGS.timit_paths:
problems = [p for p in problems if "timit" not in p]
# Remove parsing if paths are not given.
if not FLAGS.parsing_path:
problems = [p for p in problems if "parsing" not in p]
# Remove en-de BPE if paths are not given.
if not FLAGS.ende_bpe_path:
problems = [p for p in problems if "ende_bpe" not in p]

if not problems:
problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
error_msg = ("You must specify one of the supported problems to "
"generate data for:\n * " + problems_str + "\n")
error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
"--timit_paths, --ende_bpe_path and --parsing_path.")
raise ValueError(error_msg)

if not FLAGS.data_dir:
@@ -352,26 +329,28 @@ def main(_):
"Data will be written to default data_dir=%s.",
FLAGS.data_dir)

set_random_seed()
tf.logging.info("Generating problems:\n * %s\n" % "\n * ".join(problems))
for problem in problems:
set_random_seed()

training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[FLAGS.problem]
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]

tf.logging.info("Generating training data for %s.", FLAGS.problem)
train_output_files = generator_utils.generate_files(
training_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-train",
FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)
tf.logging.info("Generating training data for %s.", problem)
train_output_files = generator_utils.generate_files(
training_gen(), problem + UNSHUFFLED_SUFFIX + "-train",
FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)

tf.logging.info("Generating development data for %s.", FLAGS.problem)
dev_output_files = generator_utils.generate_files(
dev_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)
tf.logging.info("Generating development data for %s.", problem)
dev_output_files = generator_utils.generate_files(
dev_gen(), problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)

tf.logging.info("Shuffling data...")
for fname in train_output_files + dev_output_files:
records = generator_utils.read_records(fname)
random.shuffle(records)
out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
generator_utils.write_records(records, out_fname)
tf.gfile.Remove(fname)
tf.logging.info("Shuffling data...")
for fname in train_output_files + dev_output_files:
records = generator_utils.read_records(fname)
random.shuffle(records)
out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
generator_utils.write_records(records, out_fname)
tf.gfile.Remove(fname)


if __name__ == "__main__":
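The reworked main() in t2t-datagen above expands a trailing `*` in --problem against the problem registry and drops problems whose extra data flags are unset. Below is a minimal standalone sketch of that selection logic, with an abbreviated registry and illustrative flag values; select_problems is a hypothetical helper, not part of the codebase.

```python
# Sketch of the new problem-selection behaviour in t2t-datagen's main().
# The registry below is abbreviated; names are taken from the diff above.
SUPPORTED_PROBLEMS = [
    "algorithmic_reverse_nlplike_decimal32K",
    "algorithmic_reverse_nlplike_decimal8K",
    "wmt_ende_tokens_32k",
    "wmt_parsing_tokens_8k",
]

def select_problems(problem_flag, timit_paths="", parsing_path="", ende_bpe_path=""):
  problems = sorted(SUPPORTED_PROBLEMS)
  if problem_flag and problem_flag[-1] == "*":
    # A trailing "*" acts as a prefix wildcard over the registry.
    problems = [p for p in problems if p.startswith(problem_flag[:-1])]
  elif problem_flag:
    problems = [p for p in problems if p == problem_flag]
  else:
    problems = []
  # Problems that need extra data paths are skipped when their flags are unset.
  if not timit_paths:
    problems = [p for p in problems if "timit" not in p]
  if not parsing_path:
    problems = [p for p in problems if "parsing" not in p]
  if not ende_bpe_path:
    problems = [p for p in problems if "ende_bpe" not in p]
  return problems

print(select_problems("algorithmic_reverse_*"))
# ['algorithmic_reverse_nlplike_decimal32K', 'algorithmic_reverse_nlplike_decimal8K']
print(select_problems("wmt_parsing_tokens_8k"))                      # [] -- parsing_path not set
print(select_problems("wmt_parsing_tokens_8k", parsing_path="wsj"))  # ['wmt_parsing_tokens_8k']
```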
2 changes: 2 additions & 0 deletions tensor2tensor/data_generators/algorithmic_math.py
@@ -582,4 +582,6 @@ def calculus_integrate(alphabet_size=26,
}
except: # pylint:disable=bare-except
continue
if nbr_case % 10000 == 0:
print(" calculus_integrate: generating case %d." % nbr_case)
nbr_case += 1
18 changes: 11 additions & 7 deletions tensor2tensor/data_generators/generator_utils.py
@@ -46,12 +46,13 @@ def to_example(dictionary):
elif isinstance(v[0], float):
features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
elif isinstance(v[0], six.string_types):
v = [bytes(x, 'utf-8') for x in v]
if not six.PY2: # Convert in python 3.
v = [bytes(x, "utf-8") for x in v]
features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
elif isinstance(v[0], bytes):
features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
else:
raise ValueError("Value for %s is neither an int nor a float; v: %s type: %s" %
raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
(k, str(v[0]), str(type(v[0]))))
return tf.train.Example(features=tf.train.Features(feature=features))

@@ -114,7 +115,7 @@ def generate_files(generator,

counter, shard = 0, 0
for case in generator:
if counter % 100000 == 0:
if counter > 0 and counter % 100000 == 0:
tf.logging.info("Generating case %d for %s." % (counter, output_name))
counter += 1
if max_cases and counter > max_cases:
@@ -179,6 +180,9 @@ def gunzip_file(gz_path, new_path):
gz_path: path to the zipped file.
new_path: path to where the file will be unzipped.
"""
if tf.gfile.Exists(new_path):
tf.logging.info("File %s already exists, skipping unpacking" % new_path)
return
tf.logging.info("Unpacking %s to %s" % (gz_path, new_path))
with gzip.open(gz_path, "rb") as gz_file:
with io.open(new_path, "wb") as new_file:
@@ -224,7 +228,7 @@ def gunzip_file(gz_path, new_path):
def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
"""Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
vocab_filepath = os.path.join(tmp_dir, vocab_filename)
if os.path.exists(vocab_filepath):
if tf.gfile.Exists(vocab_filepath):
tf.logging.info("Found vocab file: %s", vocab_filepath)
vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
return vocab
@@ -249,7 +253,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
# For some datasets a second extraction is necessary.
if ".gz" in lang_file:
new_filepath = os.path.join(tmp_dir, lang_file[:-3])
if os.path.exists(new_filepath):
if tf.gfile.Exists(new_filepath):
tf.logging.info("Subdirectory %s already exists, skipping unpacking"
% filepath)
else:
@@ -278,7 +282,7 @@ def read_records(filename):
records = []
for record in reader:
records.append(record)
if len(records) % 10000 == 0:
if len(records) % 100000 == 0:
tf.logging.info("read: %d", len(records))
return records

@@ -287,6 +291,6 @@ def write_records(records, out_filename):
writer = tf.python_io.TFRecordWriter(out_filename)
for count, record in enumerate(records):
writer.write(record)
if count % 10000 == 0:
if count > 0 and count % 100000 == 0:
tf.logging.info("write: %d", count)
writer.close()
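The six.PY2 guard added to to_example above matters because tf.train.BytesList accepts bytes but not Python 3 str. Here is a hedged sketch of that conversion in isolation; strings_to_bytes_feature is an illustrative helper, not part of the library.

```python
# Minimal sketch of the str -> bytes handling added to to_example.
import six
import tensorflow as tf

def strings_to_bytes_feature(values):
  """Illustrative helper mirroring the Python 3 conversion in to_example."""
  if not six.PY2:  # On Python 3, str must be encoded before building a BytesList.
    values = [bytes(x, "utf-8") for x in values]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

feature = strings_to_bytes_feature(["hello", "world"])
```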
27 changes: 14 additions & 13 deletions tensor2tensor/data_generators/wmt.py
@@ -25,10 +25,19 @@

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import wsj_parsing

import tensorflow as tf


tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir."
"Download from https://drive.google.com/open?"
"id=0B_bZck-ksdkpM25jRUN2X2UxMm8")


FLAGS = tf.flags.FLAGS


# End-of-sentence marker (should correspond to the position of EOS in the
# RESERVED_TOKENS list in text_encoder.py)
EOS = 1
@@ -100,7 +109,7 @@ def _get_wmt_ende_dataset(directory, filename):
# We expect that this file has been downloaded from:
# https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed
# in `directory`.
corpus_file = os.path.join(directory, "wmt16_en_de.tar.gz")
corpus_file = os.path.join(directory, FLAGS.ende_bpe_path)
with tarfile.open(corpus_file, "r:gz") as corpus_tar:
corpus_tar.extractall(directory)
return train_path
@@ -265,18 +274,10 @@ def enfr_character_generator(tmp_dir, train):
character_vocab, EOS)


def parsing_character_generator(tmp_dir, train):
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_%s" % ("train" if train else "dev")
text_filepath = os.path.join(tmp_dir, filename + ".text")
tags_filepath = os.path.join(tmp_dir, filename + ".tags")
return character_generator(text_filepath, tags_filepath, character_vocab, EOS)


def parsing_token_generator(tmp_dir, train, vocab_size):
symbolizer_vocab = generator_utils.get_or_generate_vocab(
tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
filename = "parsing_%s" % ("train" if train else "dev")
text_filepath = os.path.join(tmp_dir, filename + ".text")
tags_filepath = os.path.join(tmp_dir, filename + ".tags")
return token_generator(text_filepath, tags_filepath, symbolizer_vocab, EOS)
filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
tree_filepath = os.path.join(tmp_dir, filename)
return wsj_parsing.token_generator(tree_filepath,
symbolizer_vocab, symbolizer_vocab, EOS)
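The parsing generators now derive their tree filenames from the new --parsing_path flag rather than a hard-coded "parsing_" prefix. A small sketch of that construction with illustrative values; parsing_tree_filepath is a hypothetical helper, and "wsj" and the tmp_dir below are placeholders.

```python
# Sketch of how --parsing_path is combined with the train/dev split
# to locate the .trees files in tmp_dir.
import os

def parsing_tree_filepath(tmp_dir, parsing_path, train):
  filename = "%s_%s.trees" % (parsing_path, "train" if train else "dev")
  return os.path.join(tmp_dir, filename)

print(parsing_tree_filepath("/tmp/t2t_datagen", "wsj", True))   # /tmp/t2t_datagen/wsj_train.trees
print(parsing_tree_filepath("/tmp/t2t_datagen", "wsj", False))  # /tmp/t2t_datagen/wsj_dev.trees
```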
13 changes: 9 additions & 4 deletions tensor2tensor/data_generators/wsj_parsing.py
@@ -23,6 +23,12 @@
import tensorflow as tf


tf.flags.DEFINE_string("parsing_path", "", "Path to parsing files in tmp_dir.")


FLAGS = tf.flags.FLAGS


def words_and_tags_from_wsj_tree(tree_string):
"""Generates linearized trees and tokens from the wsj tree format.
@@ -84,9 +90,8 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size,
target_vocab_size):
"""Generator for parsing as a sequence-to-sequence task that uses tokens.
This generator assumes the files parsing_{train,dev}.wsj, which contain trees
in wsj format and wsj_{source,target}.tokens.vocab.<vocab_size> exist in
tmp_dir.
This generator assumes the files parsing_{train,dev}.trees, which contain
trees in wsj format.
Args:
tmp_dir: path to the file with source sentences.
@@ -103,7 +108,7 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size,
target_symbolizer_vocab = generator_utils.get_or_generate_vocab(
tmp_dir, "wsj_target.tokens.vocab.%d" % target_vocab_size,
target_vocab_size)
filename = "parsing_%s.trees" % ("train" if train else "dev")
filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
tree_filepath = os.path.join(tmp_dir, filename)
return token_generator(tree_filepath, source_symbolizer_vocab,
target_symbolizer_vocab, 1)
2 changes: 1 addition & 1 deletion tensor2tensor/models/lstm.py
@@ -268,7 +268,7 @@ def model_fn_body(self, features):
def lstm_attention():
"""hparams for LSTM with attention."""
hparams = common_hparams.basic_params1()
hparams.batch_size = 128
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 2

29 changes: 24 additions & 5 deletions tensor2tensor/models/modalities.py
@@ -124,6 +124,10 @@ def top(self, body_output, targets):
class SmallImageModality(modality.Modality):
"""Performs strided conv compressions for small image data."""

def __init__(self, model_hparams, vocab_size):
super(SmallImageModality, self).__init__(model_hparams, vocab_size)
self._channels = 3

@property
def top_dimensionality(self):
return 256
@@ -161,15 +165,30 @@ def targets_bottom(self, inputs):

def top(self, body_output, _):
with tf.variable_scope("rgb_softmax"):
var = tf.get_variable(
# separate embedding for each channel
# assuming the body output returns a tensor of shape
# [batch_size, rows, cols, channels, self._body_input_depth]
body_output_split = tf.split(body_output, self._channels, axis=3)
output_rgb_embedding_var = tf.get_variable(
"output_rgb_embedding",
[self.top_dimensionality, self._body_input_depth],
[self._channels, self.top_dimensionality, self._body_input_depth],
initializer=tf.random_normal_initializer(0.0, self._body_input_depth
**-0.5))
body_output = tf.reshape(body_output, [-1, self._body_input_depth])
logits = tf.matmul(body_output, var, transpose_b=True)
# compute logits separately for each channel
rgb_channel_logits = []
for i in range(self._channels):
shape = tf.shape(body_output_split[i])[:-1]
body_output = tf.reshape(body_output_split[i],
[-1, self._body_input_depth])
channel_logits = tf.matmul(body_output,
output_rgb_embedding_var[i],
transpose_b=True)
rgb_channel_logits.append(tf.reshape(
channel_logits, tf.concat([shape, [self.top_dimensionality]],
0)))

logits = tf.concat(rgb_channel_logits, axis=3)
# Reshape logits to conform to CIFAR image shapes (32 by 32 by 3)
logits = tf.reshape(logits, [-1, 32, 32, 3, 256])

return logits

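The rewritten SmallImageModality.top above computes a separate softmax per RGB channel. The NumPy sketch below traces the shape bookkeeping with made-up sizes, assuming (as the diff's comment states) a body output of shape [batch, rows, cols, channels, depth]; it is illustrative only, not the library code.

```python
# NumPy sketch of the per-channel RGB logits computation (illustrative sizes).
import numpy as np

batch, rows, cols, channels, depth, vocab = 2, 32, 32, 3, 8, 256
body_output = np.random.randn(batch, rows, cols, channels, depth)
# One output embedding per channel, analogous to output_rgb_embedding_var.
rgb_embedding = np.random.randn(channels, vocab, depth)

channel_logits = []
for i in range(channels):
  flat = body_output[:, :, :, i, :].reshape(-1, depth)      # [B*H*W, depth]
  logits_i = flat.dot(rgb_embedding[i].T)                    # [B*H*W, vocab]
  channel_logits.append(logits_i.reshape(batch, rows, cols, 1, vocab))

logits = np.concatenate(channel_logits, axis=3)              # [B, H, W, channels, vocab]
print(logits.shape)  # (2, 32, 32, 3, 256)
```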