Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #204 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.1.5
  • Loading branch information
lukaszkaiser authored Aug 3, 2017
2 parents c35c7a3 + eee190b commit 82cce52
Show file tree
Hide file tree
Showing 42 changed files with 1,659 additions and 1,515 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _pycache__/

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
/*.egg

# PyPI distribution artifacts.
build/
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.1.4',
version='1.1.5',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand All @@ -20,6 +20,7 @@
],
install_requires=[
'numpy',
'requests',
'sympy',
'six',
],
Expand Down
25 changes: 13 additions & 12 deletions tensor2tensor/bin/t2t-datagen
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

"""Produces the training and dev data for --problem into --data_dir.
generator.py produces sharded and shuffled TFRecord files of tensorflow.Example
protocol buffers for a variety of datasets registered in this file.
Produces sharded and shuffled TFRecord files of tensorflow.Example protocol
buffers for a variety of registered datasets.
All datasets are registered in _SUPPORTED_PROBLEM_GENERATORS. Each entry maps a
string name (selectable on the command-line with --problem) to a function that
takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
yields for each training example a dictionary mapping string feature names to
lists of {string, int, float}. The generator will be run once for each mode.
All Problems are registered with @registry.register_problem or are in
_SUPPORTED_PROBLEM_GENERATORS in this file. Each entry maps a string name
(selectable on the command-line with --problem) to a function that takes 2
arguments - input_directory and mode (one of "train" or "dev") - and yields for
each training example a dictionary mapping string feature names to lists of
{string, int, float}. The generator will be run once for each mode.
"""
from __future__ import absolute_import
from __future__ import division
Expand Down Expand Up @@ -229,8 +230,7 @@ def generate_data_for_problem(problem):
num_shards = FLAGS.num_shards or 10
tf.logging.info("Generating training data for %s.", problem)
train_output_files = generator_utils.train_data_filenames(
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
num_shards)
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
generator_utils.generate_files(training_gen(), train_output_files,
FLAGS.max_cases)
tf.logging.info("Generating development data for %s.", problem)
Expand All @@ -250,9 +250,10 @@ def generate_data_for_registered_problem(problem_name):
raise ValueError("--num_shards should not be set for registered Problem.")
problem = registry.problem(problem_name)
task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
problem.generate_data(os.path.expanduser(FLAGS.data_dir),
os.path.expanduser(FLAGS.tmp_dir),
task_id=task_id)
problem.generate_data(
os.path.expanduser(FLAGS.data_dir),
os.path.expanduser(FLAGS.tmp_dir),
task_id=task_id)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/all_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing


# Problem modules that require optional dependencies
# pylint: disable=g-import-not-at-top
try:
Expand Down
45 changes: 27 additions & 18 deletions tensor2tensor/data_generators/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.models import common_layers
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry

import tensorflow as tf
Expand Down Expand Up @@ -76,10 +76,11 @@ class ImageFSNS(ImageProblem):
def generate_data(self, data_dir, tmp_dir, task_id=-1):
list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
"street/python/fsns_urls.txt")
fsns_urls = generator_utils.maybe_download(
tmp_dir, "fsns_urls.txt", list_url)
fsns_files = [f.strip() for f in open(fsns_urls, "r")
if f.startswith("http://")]
fsns_urls = generator_utils.maybe_download(tmp_dir, "fsns_urls.txt",
list_url)
fsns_files = [
f.strip() for f in open(fsns_urls, "r") if f.startswith("http://")
]
for url in fsns_files:
if "/train/train" in url:
generator_utils.maybe_download(
Expand All @@ -88,8 +89,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
generator_utils.maybe_download(
data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url)
elif "charset" in url:
generator_utils.maybe_download(
data_dir, "charset_size134.txt", url)
generator_utils.maybe_download(data_dir, "charset_size134.txt", url)

def feature_encoders(self, data_dir):
# This vocab file must be present within the data directory.
Expand All @@ -111,8 +111,8 @@ def hparams(self, defaults, model_hparams):

def example_reading_spec(self):
label_key = "image/unpadded_label"
return super(ImageFSNS, self).example_reading_spec(self,
label_key=label_key)
return super(ImageFSNS, self).example_reading_spec(
self, label_key=label_key)


class Image2ClassProblem(ImageProblem):
Expand Down Expand Up @@ -161,6 +161,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):

def imagenet_preprocess_examples(examples, mode):
"""Preprocessing used for Imagenet and similar problems."""

def preprocess(img):
img = tf.image.resize_images(img, [360, 360])
img = common_layers.image_augmentation(tf.to_float(img) / 255.)
Expand Down Expand Up @@ -215,8 +216,8 @@ def is_small(self):

def preprocess_examples(self, examples, mode):
examples = imagenet_preprocess_examples(examples, mode)
examples["inputs"] = tf.to_int64(tf.image.resize_images(
examples["inputs"], [32, 32]))
examples["inputs"] = tf.to_int64(
tf.image.resize_images(examples["inputs"], [32, 32]))


def image_generator(images, labels):
Expand Down Expand Up @@ -665,12 +666,20 @@ def generator(self, data_dir, tmp_dir, is_training):
vocab_filename = "vocab.endefr.%d" % self.targeted_vocab_size
if is_training:
return mscoco_generator(
data_dir, tmp_dir, True, 80000,
vocab_filename=vocab_filename, vocab_size=self.targeted_vocab_size)
data_dir,
tmp_dir,
True,
80000,
vocab_filename=vocab_filename,
vocab_size=self.targeted_vocab_size)
else:
return mscoco_generator(
data_dir, tmp_dir, False, 40000,
vocab_filename=vocab_filename, vocab_size=self.targeted_vocab_size)
data_dir,
tmp_dir,
False,
40000,
vocab_filename=vocab_filename,
vocab_size=self.targeted_vocab_size)


@registry.register_problem
Expand All @@ -690,8 +699,8 @@ def targeted_vocab_size(self):
def _get_celeba(directory):
"""Download and extract CELEBA to directory unless it is there."""
# path = os.path.join(directory, _CELEBA_NAME)
path = generator_utils.maybe_download_from_drive(directory,
_CELEBA_NAME, _CELEBA_URL)
path = generator_utils.maybe_download_from_drive(directory, _CELEBA_NAME,
_CELEBA_URL)
if not tf.gfile.Exists(path):
zipfile.ZipFile(path + ".zip", "r").extractall(directory)

Expand All @@ -711,7 +720,7 @@ def celeba_generator(tmp_dir, how_many, start_from=0):
"""
_get_celeba(tmp_dir)
image_files = tf.gfile.Glob(os.path.join(tmp_dir, _CELEBA_NAME) + "/*.jpg")
for filename in image_files[start_from:start_from+how_many]:
for filename in image_files[start_from:start_from + how_many]:
with tf.gfile.Open(filename, "r") as f:
encoded_image_data = f.read()
yield {
Expand Down
88 changes: 47 additions & 41 deletions tensor2tensor/data_generators/problem_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Dependency imports

from tensor2tensor.data_generators import text_encoder
from tensor2tensor.models import modalities # pylint: disable=unused-import
from tensor2tensor.layers import modalities # pylint: disable=unused-import
from tensor2tensor.utils import registry

import tensorflow as tf
Expand Down Expand Up @@ -202,8 +202,7 @@ def default_problem_hparams():
# the targets. For instance `problem_copy` will copy the inputs, but
# `problem_rev_copy` will copy the targets.
was_reversed=False,
was_copy=False,
)
was_copy=False,)


def test_problem_hparams(unused_model_hparams, input_vocab_size,
Expand Down Expand Up @@ -327,9 +326,7 @@ def lm1b_32k(model_hparams):
encoder = text_encoder.SubwordTextEncoder(
os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder"))
p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
p.vocabulary = {
"targets": encoder
}
p.vocabulary = {"targets": encoder}
p.target_space_id = 3
return p

Expand All @@ -343,9 +340,7 @@ def lm1b_characters(unused_model_hparams):
p.input_modality = {}
encoder = text_encoder.ByteTextEncoder()
p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
p.vocabulary = {
"targets": encoder
}
p.vocabulary = {"targets": encoder}
p.target_space_id = 2
return p

Expand All @@ -358,10 +353,7 @@ def wiki_32k(model_hparams):
modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size)
p.input_modality = {"inputs": modality_spec}
p.target_modality = modality_spec
p.vocabulary = {
"inputs": encoder,
"targets": encoder
}
p.vocabulary = {"inputs": encoder, "targets": encoder}
p.target_space_id = 3
return p

Expand Down Expand Up @@ -430,9 +422,7 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size):
return p


def wsj_parsing_tokens(model_hparams,
prefix,
wrong_source_vocab_size,
def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
wrong_target_vocab_size):
"""English to parse tree translation benchmark.
Expand Down Expand Up @@ -487,11 +477,9 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
p = default_problem_hparams()
# This vocab file must be present within the data directory.
source_vocab_filename = os.path.join(
model_hparams.data_dir,
"ice_source.vocab.%d" % wrong_source_vocab_size)
target_vocab_filename = os.path.join(
model_hparams.data_dir,
"ice_target.vocab.256")
model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size)
target_vocab_filename = os.path.join(model_hparams.data_dir,
"ice_target.vocab.256")
source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
p.input_modality = {
Expand All @@ -502,7 +490,7 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
"inputs": source_subtokenizer,
"targets": target_subtokenizer,
}
p.input_space_id = 18 # Icelandic tokens
p.input_space_id = 18 # Icelandic tokens
p.target_space_id = 19 # Icelandic parse tokens
return p

Expand Down Expand Up @@ -534,23 +522,41 @@ def image_celeba(unused_model_hparams):
# Dictionary of named hyperparameter settings for various problems.
# This is only accessed through the problem_hparams function below.
PROBLEM_HPARAMS_MAP = {
"audio_timit_characters_tune": audio_timit_characters,
"audio_timit_characters_test": audio_timit_characters,
"audio_timit_tokens_8k_tune": lambda p: audio_timit_tokens(p, 2**13),
"audio_timit_tokens_8k_test": lambda p: audio_timit_tokens(p, 2**13),
"audio_wsj_characters_tune": audio_wsj_characters,
"audio_wsj_characters_test": audio_wsj_characters,
"audio_wsj_tokens_8k_tune": lambda p: audio_wsj_tokens(p, 2**13),
"audio_wsj_tokens_8k_test": lambda p: audio_wsj_tokens(p, 2**13),
"lm1b_characters": lm1b_characters,
"lm1b_32k": lm1b_32k,
"wiki_32k": wiki_32k,
"ice_parsing_characters": wmt_parsing_characters,
"ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
"wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
"wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda
p, "wsj", 2**14, 2**9),
"wmt_ende_bpe32k": wmt_ende_bpe32k,
"image_celeba_tune": image_celeba,
"img2img_imagenet": img2img_imagenet,
"audio_timit_characters_tune":
audio_timit_characters,
"audio_timit_characters_test":
audio_timit_characters,
"audio_timit_tokens_8k_tune":
lambda p: audio_timit_tokens(p, 2**13),
"audio_timit_tokens_8k_test":
lambda p: audio_timit_tokens(p, 2**13),
"audio_wsj_characters_tune":
audio_wsj_characters,
"audio_wsj_characters_test":
audio_wsj_characters,
"audio_wsj_tokens_8k_tune":
lambda p: audio_wsj_tokens(p, 2**13),
"audio_wsj_tokens_8k_test":
lambda p: audio_wsj_tokens(p, 2**13),
"lm1b_characters":
lm1b_characters,
"lm1b_32k":
lm1b_32k,
"wiki_32k":
wiki_32k,
"ice_parsing_characters":
wmt_parsing_characters,
"ice_parsing_tokens":
lambda p: ice_parsing_tokens(p, 2**13),
"wmt_parsing_tokens_8k":
lambda p: wmt_parsing_tokens(p, 2**13),
"wsj_parsing_tokens_16k":
lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda
p, "wsj", 2**14, 2**9),
"wmt_ende_bpe32k":
wmt_ende_bpe32k,
"image_celeba_tune":
image_celeba,
"img2img_imagenet":
img2img_imagenet,
}
Empty file.
Loading

0 comments on commit 82cce52

Please sign in to comment.