diff --git a/README.md b/README.md
index 9adca7f45..1fdd7e883 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,7 @@ t2t-datagen \
   --num_shards=100 \
   --problem=$PROBLEM
 
-mv $TMP_DIR/tokens.vocab.32768 $DATA_DIR
+cp $TMP_DIR/tokens.vocab.* $DATA_DIR
 
 # Train
 # * If you run out of memory, add --hparams='batch_size=2048' or even 1024.
diff --git a/setup.py b/setup.py
index beb3513e1..821a88ee2 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.0.11',
+    version='1.0.12',
    description='Tensor2Tensor',
    author='Google Inc.',
    author_email='no-reply@google.com',
diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
index f45f63744..4e7e4529a 100644
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -90,25 +90,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_reverse_nlplike_decimal8K": (
         lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000,
                                                       10, 1.300),
-        lambda: algorithmic.reverse_generator_nlplike(8000, 700, 10000,
+        lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000,
                                                       10, 1.300)),
     "algorithmic_reverse_nlplike_decimal32K": (
         lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000,
                                                       10, 1.050),
-        lambda: algorithmic.reverse_generator_nlplike(32000, 700, 10000,
+        lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000,
                                                       10, 1.050)),
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
-    "algorithmic_algebra_simplify": (
-        lambda: algorithmic_math.algebra_simplify(8, 0, 2, 100000),
-        lambda: algorithmic_math.algebra_simplify(8, 3, 3, 10000)),
-    "algorithmic_calculus_integrate": (
-        lambda: algorithmic_math.calculus_integrate(8, 0, 2, 100000),
-        lambda: algorithmic_math.calculus_integrate(8, 3, 3, 10000)),
-    "wmt_parsing_characters": (
-        lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, True),
-        lambda: wmt.parsing_character_generator(FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
@@ -133,10 +124,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
-    "wmt_enfr_tokens_128k": (
-        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
-        lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
-    ),
     "wmt_ende_characters": (
         lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),
@@ -151,10 +138,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
-    "wmt_ende_tokens_128k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**17),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**17)
-    ),
     "image_mnist_tune": (
         lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
         lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
@@ -227,33 +210,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             40000,
             vocab_filename="tokens.vocab.%d" % 2**15,
             vocab_size=2**15)),
-    "image_mscoco_tokens_128k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17)),
-    "image_mscoco_tokens_128k_test": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            80000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            False,
-            40000,
-            vocab_filename="tokens.vocab.%d" % 2**17,
-            vocab_size=2**17)),
     "snli_32k": (
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
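Each registry entry above pairs a training generator with a dev generator behind zero-argument lambdas, so no data is produced until a problem is actually selected. A minimal standalone sketch of that pattern (`_GENERATORS` and `toy_identity` are illustrative names, not the real t2t-datagen internals):

```python
# Illustrative registry: problem name -> (train_generator, dev_generator).
# The zero-argument lambdas delay generation until a problem is selected.
_GENERATORS = {
    "toy_identity": (
        lambda: ({"inputs": [i], "targets": [i]} for i in range(100)),
        lambda: ({"inputs": [i], "targets": [i]} for i in range(10)),
    ),
}

train_gen, dev_gen = _GENERATORS["toy_identity"]
print(len(list(train_gen())), len(list(dev_gen())))  # 100 10
```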
@@ -340,10 +296,31 @@ def set_random_seed():
 
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
-  if FLAGS.problem not in _SUPPORTED_PROBLEM_GENERATORS:
+
+  # Calculate the list of problems to generate.
+  problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+  if FLAGS.problem and FLAGS.problem[-1] == "*":
+    problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
+  elif FLAGS.problem:
+    problems = [p for p in problems if p == FLAGS.problem]
+  else:
+    problems = []
+  # Remove TIMIT if paths are not given.
+  if not FLAGS.timit_paths:
+    problems = [p for p in problems if "timit" not in p]
+  # Remove parsing if paths are not given.
+  if not FLAGS.parsing_path:
+    problems = [p for p in problems if "parsing" not in p]
+  # Remove en-de BPE if paths are not given.
+  if not FLAGS.ende_bpe_path:
+    problems = [p for p in problems if "ende_bpe" not in p]
+
+  if not problems:
     problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
     error_msg = ("You must specify one of the supported problems to "
                  "generate data for:\n * " + problems_str + "\n")
+    error_msg += ("TIMIT, ende_bpe and parsing need data sets specified with "
+                  "--timit_paths, --ende_bpe_path and --parsing_path.")
     raise ValueError(error_msg)
 
   if not FLAGS.data_dir:
@@ -352,26 +329,28 @@ def main(_):
                     "Data will be written to default data_dir=%s.",
                     FLAGS.data_dir)
 
-  set_random_seed()
+  tf.logging.info("Generating problems:\n * %s\n" % "\n * ".join(problems))
+  for problem in problems:
+    set_random_seed()
 
-  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[FLAGS.problem]
+    training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
 
-  tf.logging.info("Generating training data for %s.", FLAGS.problem)
-  train_output_files = generator_utils.generate_files(
-      training_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-train",
-      FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)
+    tf.logging.info("Generating training data for %s.", problem)
+    train_output_files = generator_utils.generate_files(
+        training_gen(), problem + UNSHUFFLED_SUFFIX + "-train",
+        FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)
 
-  tf.logging.info("Generating development data for %s.", FLAGS.problem)
-  dev_output_files = generator_utils.generate_files(
-      dev_gen(), FLAGS.problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)
+    tf.logging.info("Generating development data for %s.", problem)
+    dev_output_files = generator_utils.generate_files(
+        dev_gen(), problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)
 
-  tf.logging.info("Shuffling data...")
-  for fname in train_output_files + dev_output_files:
-    records = generator_utils.read_records(fname)
-    random.shuffle(records)
-    out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
-    generator_utils.write_records(records, out_fname)
-    tf.gfile.Remove(fname)
+    tf.logging.info("Shuffling data...")
+    for fname in train_output_files + dev_output_files:
+      records = generator_utils.read_records(fname)
+      random.shuffle(records)
+      out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
+      generator_utils.write_records(records, out_fname)
+      tf.gfile.Remove(fname)
 
 
 if __name__ == "__main__":
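The rewritten `main` above turns a trailing `*` in `--problem` into a prefix match over the sorted registry keys, then drops problems whose required paths are missing. A self-contained sketch of the selection behavior (the sample problem names are for illustration only):

```python
# Prefix-wildcard selection, mirroring the logic in t2t-datagen's main().
SUPPORTED = ["algorithmic_algebra_inverse",
             "algorithmic_reverse_nlplike_decimal8K",
             "wmt_ende_tokens_32k"]

def select_problems(flag_value, supported):
  problems = sorted(supported)
  if flag_value and flag_value[-1] == "*":
    return [p for p in problems if p.startswith(flag_value[:-1])]
  if flag_value:
    return [p for p in problems if p == flag_value]
  return []

print(select_problems("algorithmic*", SUPPORTED))
# ['algorithmic_algebra_inverse', 'algorithmic_reverse_nlplike_decimal8K']
print(select_problems("wmt_ende_tokens_32k", SUPPORTED))
# ['wmt_ende_tokens_32k']
```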
diff --git a/tensor2tensor/data_generators/algorithmic_math.py b/tensor2tensor/data_generators/algorithmic_math.py
index f5c954036..ec3b7670a 100644
--- a/tensor2tensor/data_generators/algorithmic_math.py
+++ b/tensor2tensor/data_generators/algorithmic_math.py
@@ -582,4 +582,6 @@ def calculus_integrate(alphabet_size=26,
       }
     except:  # pylint:disable=bare-except
       continue
+    if nbr_case % 10000 == 0:
+      print(" calculus_integrate: generating case %d." % nbr_case)
     nbr_case += 1
diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py
index bc354a86d..6a3475456 100644
--- a/tensor2tensor/data_generators/generator_utils.py
+++ b/tensor2tensor/data_generators/generator_utils.py
@@ -46,12 +46,13 @@ def to_example(dictionary):
     elif isinstance(v[0], float):
       features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
     elif isinstance(v[0], six.string_types):
-      v = [bytes(x, 'utf-8') for x in v]
+      if not six.PY2:  # Convert in python 3.
+        v = [bytes(x, "utf-8") for x in v]
       features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
     elif isinstance(v[0], bytes):
       features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
     else:
-      raise ValueError("Value for %s is neither an int nor a float; v: %s type: %s" %
+      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                        (k, str(v[0]), str(type(v[0]))))
   return tf.train.Example(features=tf.train.Features(feature=features))
 
@@ -114,7 +115,7 @@ def generate_files(generator,
   counter, shard = 0, 0
   for case in generator:
-    if counter % 100000 == 0:
+    if counter > 0 and counter % 100000 == 0:
       tf.logging.info("Generating case %d for %s." % (counter, output_name))
     counter += 1
     if max_cases and counter > max_cases:
       break
@@ -179,6 +180,9 @@ def gunzip_file(gz_path, new_path):
     gz_path: path to the zipped file.
     new_path: path to where the file will be unzipped.
   """
+  if tf.gfile.Exists(new_path):
+    tf.logging.info("File %s already exists, skipping unpacking" % new_path)
+    return
   tf.logging.info("Unpacking %s to %s" % (gz_path, new_path))
   with gzip.open(gz_path, "rb") as gz_file:
     with io.open(new_path, "wb") as new_file:
@@ -224,7 +228,7 @@ def gunzip_file(gz_path, new_path):
 def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
   """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
-  if os.path.exists(vocab_filepath):
+  if tf.gfile.Exists(vocab_filepath):
     tf.logging.info("Found vocab file: %s", vocab_filepath)
     vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
     return vocab
@@ -249,7 +253,7 @@
       # For some datasets a second extraction is necessary.
       if ".gz" in lang_file:
         new_filepath = os.path.join(tmp_dir, lang_file[:-3])
-        if os.path.exists(new_filepath):
+        if tf.gfile.Exists(new_filepath):
           tf.logging.info("Subdirectory %s already exists, skipping unpacking"
                           % filepath)
         else:
@@ -278,7 +282,7 @@ def read_records(filename):
   records = []
   for record in reader:
     records.append(record)
-    if len(records) % 10000 == 0:
+    if len(records) % 100000 == 0:
       tf.logging.info("read: %d", len(records))
   return records
@@ -287,6 +291,6 @@ def write_records(records, out_filename):
   writer = tf.python_io.TFRecordWriter(out_filename)
   for count, record in enumerate(records):
     writer.write(record)
-    if count % 10000 == 0:
+    if count > 0 and count % 100000 == 0:
       tf.logging.info("write: %d", count)
   writer.close()
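The `os.path.exists` → `tf.gfile.Exists` switches above matter because `tf.gfile` resolves every path scheme TensorFlow understands, including remote filesystems such as `gs://` buckets, which `os.path` cannot see. A minimal sketch of the resulting get-or-build pattern, assuming the TF 1.x `tf.gfile` API (`build_fn` is a hypothetical builder, not a t2t function):

```python
import tensorflow as tf

def get_or_build(path, build_fn):
  # tf.gfile.Exists also works for remote paths (e.g. gs://bucket/file),
  # unlike os.path.exists, which only sees the local filesystem.
  if tf.gfile.Exists(path):
    tf.logging.info("Found cached file: %s", path)
    return path
  build_fn(path)  # hypothetical callback that writes the file at `path`
  return path
```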
if ".gz" in lang_file: new_filepath = os.path.join(tmp_dir, lang_file[:-3]) - if os.path.exists(new_filepath): + if tf.gfile.Exists(new_filepath): tf.logging.info("Subdirectory %s already exists, skipping unpacking" % filepath) else: @@ -278,7 +282,7 @@ def read_records(filename): records = [] for record in reader: records.append(record) - if len(records) % 10000 == 0: + if len(records) % 100000 == 0: tf.logging.info("read: %d", len(records)) return records @@ -287,6 +291,6 @@ def write_records(records, out_filename): writer = tf.python_io.TFRecordWriter(out_filename) for count, record in enumerate(records): writer.write(record) - if count % 10000 == 0: + if count > 0 and count % 100000 == 0: tf.logging.info("write: %d", count) writer.close() diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index e88a90983..1937e1b71 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -25,10 +25,19 @@ from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import wsj_parsing import tensorflow as tf +tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir." + "Download from https://drive.google.com/open?" + "id=0B_bZck-ksdkpM25jRUN2X2UxMm8") + + +FLAGS = tf.flags.FLAGS + + # End-of-sentence marker (should correspond to the position of EOS in the # RESERVED_TOKENS list in text_encoder.py) EOS = 1 @@ -100,7 +109,7 @@ def _get_wmt_ende_dataset(directory, filename): # We expect that this file has been downloaded from: # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed # in `directory`. - corpus_file = os.path.join(directory, "wmt16_en_de.tar.gz") + corpus_file = os.path.join(directory, FLAGS.ende_bpe_path) with tarfile.open(corpus_file, "r:gz") as corpus_tar: corpus_tar.extractall(directory) return train_path @@ -265,18 +274,10 @@ def enfr_character_generator(tmp_dir, train): character_vocab, EOS) -def parsing_character_generator(tmp_dir, train): - character_vocab = text_encoder.ByteTextEncoder() - filename = "parsing_%s" % ("train" if train else "dev") - text_filepath = os.path.join(tmp_dir, filename + ".text") - tags_filepath = os.path.join(tmp_dir, filename + ".tags") - return character_generator(text_filepath, tags_filepath, character_vocab, EOS) - - def parsing_token_generator(tmp_dir, train, vocab_size): symbolizer_vocab = generator_utils.get_or_generate_vocab( tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size) - filename = "parsing_%s" % ("train" if train else "dev") - text_filepath = os.path.join(tmp_dir, filename + ".text") - tags_filepath = os.path.join(tmp_dir, filename + ".tags") - return token_generator(text_filepath, tags_filepath, symbolizer_vocab, EOS) + filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev") + tree_filepath = os.path.join(tmp_dir, filename) + return wsj_parsing.token_generator(tree_filepath, + symbolizer_vocab, symbolizer_vocab, EOS) diff --git a/tensor2tensor/data_generators/wsj_parsing.py b/tensor2tensor/data_generators/wsj_parsing.py index a2dda4d9d..756a44954 100644 --- a/tensor2tensor/data_generators/wsj_parsing.py +++ b/tensor2tensor/data_generators/wsj_parsing.py @@ -23,6 +23,12 @@ import tensorflow as tf +tf.flags.DEFINE_string("parsing_path", "", "Path to parsing files in tmp_dir.") + + +FLAGS = tf.flags.FLAGS + + def words_and_tags_from_wsj_tree(tree_string): """Generates linearized trees and tokens from the wsj tree format. 
diff --git a/tensor2tensor/data_generators/wsj_parsing.py b/tensor2tensor/data_generators/wsj_parsing.py
index a2dda4d9d..756a44954 100644
--- a/tensor2tensor/data_generators/wsj_parsing.py
+++ b/tensor2tensor/data_generators/wsj_parsing.py
@@ -23,6 +23,12 @@
 import tensorflow as tf
 
 
+tf.flags.DEFINE_string("parsing_path", "", "Path to parsing files in tmp_dir.")
+
+
+FLAGS = tf.flags.FLAGS
+
+
 def words_and_tags_from_wsj_tree(tree_string):
   """Generates linearized trees and tokens from the wsj tree format.
@@ -84,9 +90,8 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size,
                             target_vocab_size):
   """Generator for parsing as a sequence-to-sequence task that uses tokens.
 
-  This generator assumes the files parsing_{train,dev}.wsj, which contain trees
-  in wsj format and wsj_{source,target}.tokens.vocab.<vocab_size> exist in
-  tmp_dir.
+  This generator assumes the files parsing_{train,dev}.trees, which contain
+  trees in wsj format.
 
   Args:
     tmp_dir: path to the file with source sentences.
@@ -103,7 +108,7 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size,
   target_symbolizer_vocab = generator_utils.get_or_generate_vocab(
       tmp_dir, "wsj_target.tokens.vocab.%d" % target_vocab_size,
       target_vocab_size)
-  filename = "parsing_%s.trees" % ("train" if train else "dev")
+  filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
   tree_filepath = os.path.join(tmp_dir, filename)
   return token_generator(tree_filepath, source_symbolizer_vocab,
                          target_symbolizer_vocab, 1)
diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py
index eb8b10cd2..998e6756b 100644
--- a/tensor2tensor/models/lstm.py
+++ b/tensor2tensor/models/lstm.py
@@ -268,7 +268,7 @@ def model_fn_body(self, features):
 
 def lstm_attention():
   """hparams for LSTM with attention."""
   hparams = common_hparams.basic_params1()
-  hparams.batch_size = 128
+  hparams.batch_size = 1024
   hparams.hidden_size = 128
   hparams.num_hidden_layers = 2
diff --git a/tensor2tensor/models/modalities.py b/tensor2tensor/models/modalities.py
index fd9fb4432..4e7a7e924 100644
--- a/tensor2tensor/models/modalities.py
+++ b/tensor2tensor/models/modalities.py
@@ -124,6 +124,10 @@ def top(self, body_output, targets):
 class SmallImageModality(modality.Modality):
   """Performs strided conv compressions for small image data."""
 
+  def __init__(self, model_hparams, vocab_size):
+    super(SmallImageModality, self).__init__(model_hparams, vocab_size)
+    self._channels = 3
+
   @property
   def top_dimensionality(self):
     return 256
@@ -161,15 +165,30 @@ def targets_bottom(self, inputs):
 
   def top(self, body_output, _):
     with tf.variable_scope("rgb_softmax"):
-      var = tf.get_variable(
+      # Separate embedding for each channel.
+      # Assumes the body output returns a tensor of shape
+      # [batch_size, rows, cols, channels, self._body_input_depth].
+      body_output_split = tf.split(body_output, self._channels, axis=3)
+      output_rgb_embedding_var = tf.get_variable(
           "output_rgb_embedding",
-          [self.top_dimensionality, self._body_input_depth],
+          [self._channels, self.top_dimensionality, self._body_input_depth],
           initializer=tf.random_normal_initializer(0.0, self._body_input_depth
                                                    **-0.5))
-      body_output = tf.reshape(body_output, [-1, self._body_input_depth])
-      logits = tf.matmul(body_output, var, transpose_b=True)
+      # Compute logits separately for each channel.
+      rgb_channel_logits = []
+      for i in range(self._channels):
+        shape = tf.shape(body_output_split[i])[:-1]
+        body_output = tf.reshape(body_output_split[i],
+                                 [-1, self._body_input_depth])
+        channel_logits = tf.matmul(body_output,
+                                   output_rgb_embedding_var[i],
+                                   transpose_b=True)
+        rgb_channel_logits.append(tf.reshape(
+            channel_logits,
+            tf.concat([shape, [self.top_dimensionality]], 0)))
+
+      logits = tf.concat(rgb_channel_logits, axis=3)
       # Reshape logits to conform to CIFAR image shapes (32 by 32 by 3)
-      logits = tf.reshape(logits, [-1, 32, 32, 3, 256])
       return logits
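The reworked `top` above gives each RGB channel its own output embedding: the body output is split on the channel axis, each slice is projected to per-channel logits, and the slices are concatenated back. A shape-only sketch with small illustrative sizes (not the modality class itself), assuming TF 1.x graph mode:

```python
import tensorflow as tf

# Toy sizes: depth stands in for body_input_depth, vocab for 256 intensities.
batch, rows, cols, channels, depth, vocab = 2, 4, 4, 3, 8, 256
body_output = tf.zeros([batch, rows, cols, channels, depth])
embedding = tf.get_variable("rgb_emb", [channels, vocab, depth])

splits = tf.split(body_output, channels, axis=3)   # 3 x [2, 4, 4, 1, 8]
logits = []
for i in range(channels):
  flat = tf.reshape(splits[i], [-1, depth])               # [b*r*c, depth]
  chan = tf.matmul(flat, embedding[i], transpose_b=True)  # [b*r*c, vocab]
  logits.append(tf.reshape(chan, [batch, rows, cols, 1, vocab]))
out = tf.concat(logits, axis=3)
print(out.shape)  # (2, 4, 4, 3, 256)
```

Concatenating on axis 3 restores the `[batch, rows, cols, channels, vocab]` layout the downstream loss expects.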
diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index 88d901df9..544035efd 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -48,8 +48,8 @@ def model_fn_body(self, features):
     inputs = features.get("inputs")
     target_space = features.get("target_space_id")
 
-    inputs = tf.squeeze(inputs, 2)
-    targets = tf.squeeze(targets, 2)
+    inputs = common_layers.flatten4d3d(inputs)
+    targets = common_layers.flatten4d3d(targets)
 
     (encoder_input, encoder_attention_bias, _) = (transformer_prepare_encoder(
         inputs, target_space, hparams))
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 8b6422734..a991d3614 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -124,6 +124,8 @@ def _create_modalities(self, problem_hparams, hparams):
     problem_hparams.input_modality = input_modality
 
     target_modality_spec = problem_hparams.target_modality
+    if isinstance(target_modality_spec, modality.Modality):
+      return
     if target_modality_name:
       _warn_changed_modality_type(target_modality_name,
                                   target_modality_spec[0], "target")
diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py
index fc6970188..75883accd 100644
--- a/tensor2tensor/utils/trainer_utils.py
+++ b/tensor2tensor/utils/trainer_utils.py
@@ -69,6 +69,7 @@
 flags.DEFINE_integer("train_steps", 250000,
                      "The number of steps to run training for.")
 flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.")
+flags.DEFINE_bool("eval_print", False, "Print eval logits and predictions.")
 flags.DEFINE_integer("keep_checkpoint_max", 20,
                      "How many recent checkpoints to keep.")
 flags.DEFINE_bool("experimental_optimize_placement", False,
@@ -452,6 +453,9 @@ def nth_model(n):
       sharded_logits, total_loss = result_list[1:], result_list[0]
     if mode == tf.contrib.learn.ModeKeys.EVAL:
       logits = tf.concat(sharded_logits, 0)
+      if FLAGS.eval_print:
+        logits = tf.Print(logits, [features["inputs"], logits],
+                          "EVAL PRINT", summarize=10000)
       # For evaluation, return the logits layer as our predictions.
       run_info["predictions"] = logits
       train_op = None
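`tf.Print`, used by the new `--eval_print` flag above, is an identity op that dumps the listed tensors to stderr whenever values flow through it, so it is only added to the graph when the flag is set and never alters the logits' values. A minimal sketch, assuming TF 1.x where `tf.Print` is still available:

```python
import tensorflow as tf

inputs = tf.constant([[1, 2, 3], [4, 5, 6]])
logits = tf.zeros([2, 3])
# Pass-through op: returns `logits` unchanged, printing inputs and logits.
logits = tf.Print(logits, [inputs, logits], "EVAL PRINT", summarize=10000)

with tf.Session() as sess:
  sess.run(logits)  # emits "EVAL PRINT..." with both tensors to stderr
```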