Changes to wikitext-103 input pipeline to enable specifying alternative tokenizers:

- add args to specify the tokenizer type and vocab path
- change the signature of the Wikitext word `Tokenizer.tokenize` so that its argument is a tensor instead of a dataset. This lets us tokenize the dataset with `tf.data.Dataset.map`, consistent with the typical use of other tokenizers such as SentencePiece (see the usage sketch below).
- move the flattening of the dataset (to prepare for dividing it into sequences) out of the Wikitext word tokenizer and into the input pipeline
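
A minimal usage sketch of the flow described above (not part of the commit; the `train.txt` path is a placeholder, and plain `batch` stands in for the pipeline's `batch_with_padding`): the tokenizer is trained eagerly on the raw line dataset, tokenization then happens per line via `tf.data.Dataset.map`, and flattening is handled by the pipeline.

```python
import tensorflow as tf
from init2winit.dataset_lib import wikitext_tokenizer

# Raw Wikitext lines; 'train.txt' is an illustrative path.
train_lines = tf.data.TextLineDataset('train.txt')

# Train the word tokenizer (builds the dictionary and the static lookup table).
tokenizer = wikitext_tokenizer.Tokenizer()
tokenizer.train(train_lines)

# New-style tokenization: tokenize takes a string tensor, so it composes
# with Dataset.map instead of consuming the whole dataset at once.
tokenized = train_lines.map(tokenizer.tokenize)

# Flattening now lives in the input pipeline: turn the per-line id tensors
# into a flat stream of token ids, then cut it into fixed-length sequences.
flat = tokenized.flat_map(tf.data.Dataset.from_tensor_slices)
sequences = flat.batch(128)  # sequence_length=128, matching DEFAULT_HPARAMS
```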

PiperOrigin-RevId: 706822231
priyakasimbeg authored and copybara-github committed Dec 16, 2024
1 parent 43d6a8f commit 15ba961
Showing 6 changed files with 155 additions and 43 deletions.
5 changes: 5 additions & 0 deletions init2winit/dataset_lib/datasets.py
@@ -32,6 +32,7 @@
from init2winit.dataset_lib import small_image_datasets
from init2winit.dataset_lib import translate_wmt
from init2winit.dataset_lib import wikitext103
from init2winit.dataset_lib import wikitext103_spm
from init2winit.dataset_lib import wikitext2

_Dataset = collections.namedtuple(
@@ -106,6 +107,10 @@
'wikitext103':
_Dataset(wikitext103.get_wikitext103, wikitext103.DEFAULT_HPARAMS,
wikitext2.METADATA, None),
'wikitext103_spm':
_Dataset(wikitext103_spm.get_wikitext103,
wikitext103_spm.DEFAULT_HPARAMS,
wikitext103_spm.METADATA, None),
}


6 changes: 3 additions & 3 deletions init2winit/dataset_lib/test_wikitext_tokenizer.py
@@ -27,8 +27,8 @@ class TestWikitextTokenizer(absltest.TestCase):
def test_tokenizer_vocab_size(self):
"""Test vocab size.
Vocab size should be number of unique words in text file + 1 for the <eos>
token which gets added at the end of each new line.
Vocab size should be number of unique words in text file + 2 for the <eos>
and <unk> tokens.
"""
# Get number of unique tokens from tokenizer.
text_dataset = tf.data.TextLineDataset(file_name)
@@ -50,7 +50,7 @@ def test_tokenizer_vocab_size(self):
words = data.split(' ')
num_unique_words = len(set(words))

self.assertEqual(num_unique_tokens, num_unique_words + 1)
self.assertEqual(num_unique_tokens, num_unique_words + 2)

if __name__ == '__main__':
absltest.main()
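
A quick illustrative check of the updated count (toy corpus, not the wikitext test fixture): the trained vocabulary has two entries beyond the unique words, because a default `<unk>` token is added at train time and an `<eos>` token is appended to every line.

```python
# Toy corpus with four unique words.
corpus = "the cat sat\nthe cat ran"
num_unique_words = len(set(corpus.split()))   # {'the', 'cat', 'sat', 'ran'} -> 4
# Training adds a default <unk> entry, and an <eos> is appended to each line,
# so the trained vocabulary has two extra entries.
expected_vocab_size = num_unique_words + 2    # -> 6
```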
14 changes: 9 additions & 5 deletions init2winit/dataset_lib/wikitext103.py
@@ -13,18 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing hyperparameters, metadata and dataset getter for Wikitext-2 dataset."""
"""Module containing hyperparameters, metadata and dataset getter for Wikitext-103 dataset."""

import itertools

from init2winit.dataset_lib import data_utils
from init2winit.dataset_lib import wikitext103_input_pipeline as input_pipeline
from init2winit.dataset_lib.data_utils import Dataset
from init2winit.dataset_lib.wikitext2_input_pipeline import PAD_ID
from init2winit.dataset_lib import wikitext2_input_pipeline
import jax
from ml_collections.config_dict import config_dict
import numpy as np

PAD_ID = wikitext2_input_pipeline.PAD_ID
Dataset = data_utils.Dataset

VOCAB_SIZE = 267735

DEFAULT_HPARAMS = config_dict.ConfigDict(
@@ -34,9 +36,11 @@
max_eval_target_length=128,
eval_sequence_length=128,
input_shape=(128,),
output_shape=(VOCAB_SIZE,),
vocab_size=VOCAB_SIZE,
output_shape=(input_pipeline.WORD_VOCAB_SIZE,),
train_size=800210, # Number of sequences.
tokenizer='word',
tokenizer_vocab_path=None,
vocab_size=input_pipeline.WORD_VOCAB_SIZE,
))


77 changes: 61 additions & 16 deletions init2winit/dataset_lib/wikitext103_input_pipeline.py
@@ -16,24 +16,57 @@
"""Module for processing wikitext-2 train, val and test datasets from raw text files to tokenized and batched tensorflow.data.Datasets."""

import os
from typing import Union

from init2winit.dataset_lib import spm_tokenizer
from init2winit.dataset_lib import wikitext_tokenizer
from ml_collections.config_dict import config_dict
import tensorflow as tf

TRAIN_FILENAME = 'train.txt'
VALID_FILENAME = 'valid.txt'
TEST_FILENAME = 'test.txt'

SHUFFLE_BUFFER_SIZE = 1_000_000
SPM_TOKENIZER_VOCAB_SIZE = 10_000
WORD_VOCAB_SIZE = 267_735

PAD_ID = 0

AUTOTUNE = tf.data.experimental.AUTOTUNE


def get_trained_tokenizer(train_dataset: tf.data.Dataset,) -> tf.data.Dataset:
tokenizer = wikitext_tokenizer.Tokenizer()
tokenizer.train(train_dataset)
def get_trained_tokenizer(train_dataset: Union[tf.data.Dataset, str],
tokenizer: str,
vocab_path: str = SPM_TOKENIZER_VOCAB_PATH,
vocab_size: int = SPM_TOKENIZER_VOCAB_SIZE,
max_corpus_chars: int = 10_000_000,
) -> tf.data.Dataset:
"""Returns a tokenizer trained on the train dataset.
Args:
train_dataset: The training dataset to train the tokenizer on.
tokenizer: The type of tokenizer to use. Can be 'word' or 'sentencepiece'.
vocab_path: The path to the vocabulary file.
vocab_size: The size of the vocabulary.
max_corpus_chars: The maximum number of characters to use for training the
tokenizer.
Returns:
A tokenizer trained on the train dataset.
"""
if tokenizer == 'word':
tokenizer = wikitext_tokenizer.Tokenizer()
tokenizer.train(train_dataset)
elif tokenizer == 'sentencepiece':
tokenizer = spm_tokenizer.load_or_train_tokenizer(
dataset=train_dataset,
vocab_path=vocab_path,
vocab_size=vocab_size,
max_corpus_chars=max_corpus_chars,
)
else:
raise ValueError(f'Tokenizer {tokenizer} not supported.')
return tokenizer


@@ -42,7 +75,7 @@ def batch_with_padding(dataset: tf.data.Dataset,
padded_shapes=None,
padding_id=PAD_ID,
):
"""Batches a tf.data.Dataset and adds padding if len(dataset) not divisible by the batch size.
"""Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size.
Args:
dataset: tf.data.Dataset
@@ -95,33 +128,44 @@ def get_wikitext103_dataset(
test_text_dataset = tf.data.TextLineDataset(test_path)

# Tokenize data
tokenizer = get_trained_tokenizer(train_text_dataset)

train_dataset_tokenized = tokenizer.tokenize(
train_text_dataset)
valid_dataset_tokenized = tokenizer.tokenize(
valid_text_dataset)
test_dataset_tokenized = tokenizer.tokenize(
test_text_dataset)
tokenizer = get_trained_tokenizer(train_text_dataset,
hps.tokenizer,
hps.tokenizer_vocab_path,
hps.vocab_size)
train_dataset_tokenized = train_text_dataset.map(tokenizer.tokenize)
valid_dataset_tokenized = valid_text_dataset.map(tokenizer.tokenize)
test_dataset_tokenized = test_text_dataset.map(tokenizer.tokenize)

# Flatten datasets into tokens so that they can be batched in sequences of
# length sequence_length.
flattened_train_dataset = train_dataset_tokenized.flat_map(
tf.data.Dataset.from_tensor_slices
)
flattened_valid_dataset = valid_dataset_tokenized.flat_map(
tf.data.Dataset.from_tensor_slices
)
flattened_test_dataset = test_dataset_tokenized.flat_map(
tf.data.Dataset.from_tensor_slices
)

# Divide data in sequences of length sequence_length.
train_dataset_sequences = batch_with_padding(
train_dataset_tokenized,
flattened_train_dataset,
hps.sequence_length,
padded_shapes=hps.sequence_length,
)
eval_train_sequences = batch_with_padding(
train_dataset_tokenized,
flattened_train_dataset,
hps.eval_sequence_length,
padded_shapes=hps.eval_sequence_length,
)
valid_dataset_sequences = batch_with_padding(
valid_dataset_tokenized,
flattened_valid_dataset,
hps.eval_sequence_length,
padded_shapes=hps.eval_sequence_length,
)
test_dataset_sequences = batch_with_padding(
test_dataset_tokenized,
flattened_test_dataset,
hps.eval_sequence_length,
padded_shapes=hps.eval_sequence_length,
)
@@ -146,6 +190,7 @@ def get_wikitext103_dataset(
train_dataset = train_dataset_sequences.batch(
train_batch_size,
drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)

# Use padded batches for eval_train, validation and test_datasets since the
# sequences do not repeat indefinitely.
eval_train_dataset = batch_with_padding(
49 changes: 49 additions & 0 deletions init2winit/dataset_lib/wikitext103_spm.py
@@ -0,0 +1,49 @@
# coding=utf-8
# Copyright 2024 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing hyperparameters, metadata and dataset getter for Wikitext-103 dataset."""


from init2winit.dataset_lib import wikitext103
from init2winit.dataset_lib import wikitext103_input_pipeline
from ml_collections.config_dict import config_dict

SPM_TOKENIZER_VOCAB_SIZE = wikitext103_input_pipeline.SPM_TOKENIZER_VOCAB_SIZE
SPM_TOKENIZER_VOCAB_PATH = wikitext103_input_pipeline.SPM_TOKENIZER_VOCAB_PATH

get_wikitext103 = wikitext103.get_wikitext103

DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
sequence_length=128,
max_target_length=128,
max_eval_target_length=128,
eval_sequence_length=128,
input_shape=(128,),
output_shape=(SPM_TOKENIZER_VOCAB_SIZE,),
tokenizer='sentencepiece',
tokenizer_vocab_path=SPM_TOKENIZER_VOCAB_PATH,
vocab_size=SPM_TOKENIZER_VOCAB_SIZE,
train_size=800210, # TODO(kasimbeg): Update this
)
)


METADATA = {
'apply_one_hot_in_loss': True,
'shift_inputs': True,
'causal': True,
'pad_token': -1,
}
47 changes: 28 additions & 19 deletions init2winit/dataset_lib/wikitext_tokenizer.py
@@ -56,6 +56,7 @@ class Tokenizer:
Attributes:
dictionary: Dictionary containing word-to-id and id-to-word mappings
lookup_table: tf.lookup.StaticHashTable for looking up token ids from words
"""

def __init__(self):
@@ -64,26 +65,34 @@ def __init__(self):
def train(self, dataset: tf.data.TextLineDataset):
"""Trains a Tokenizer from a TextLineDataset."""
# Add words to the dictionary
self.dictionary.add_word(UNKNOWN_TOKEN) # add default unknown token
for line in dataset:
words = line.numpy().split() + [EOS_TOKEN]
for word in words:
self.dictionary.add_word(word)

def tokenize(self, dataset: tf.data.TextLineDataset) -> tf.data.Dataset:
"""Tokenizes a TextLineDataset."""
idss = []
for line in dataset:
ids = []
words = line.numpy().split() + [EOS_TOKEN]
for word in words:
try:
ids.append(self.dictionary.word2idx[word])
except KeyError:
ids.append(self.dictionary.word2idx[UNKNOWN_TOKEN])
idss.append(ids)
ids = tf.concat(idss, 0)

tokenized_dataset = tf.data.Dataset.from_tensor_slices(ids)

return tokenized_dataset

# Make static vocabulary table for tf.data style tokenization
self.lookup_table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(
tf.constant(list(self.dictionary.word2idx.keys()), dtype=tf.string),
tf.constant(
list(self.dictionary.word2idx.values()), dtype=tf.int32
),
),
default_value=self.dictionary.word2idx[UNKNOWN_TOKEN],
)

def tokenize(self, input_tensor: tf.Tensor) -> tf.Tensor:
"""Tokenizes a tensor of UTF-8 strings.
Args:
input_tensor: A `RaggedTensor` or `Tensor` of UTF-8 strings with any
shape.
Returns:
A `RaggedTensor` or `Tensor` of tokenized text. The returned shape is
the shape of the input tensor.
"""
eos_tensor = tf.constant([EOS_TOKEN], dtype=tf.string)
input_tensor_split = tf.strings.split(input_tensor)
input_tensor_extended = tf.concat([input_tensor_split, eos_tensor], axis=-1)
return self.lookup_table.lookup(input_tensor_extended)
