From 7af285aa76a257271c5488f92e99483ecd8e624f Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Mon, 19 Dec 2022 13:19:20 -0800
Subject: [PATCH] Add simple sweep experiments to compare Transformer experiments on wikitext-103 vs lm1b.

PiperOrigin-RevId: 496475812
---
 init2winit/dataset_lib/datasets.py            |   4 +
 init2winit/dataset_lib/test_wikitext103.py    |  57 ++++++
 init2winit/dataset_lib/wikitext103.py         | 124 +++++++++++++
 .../dataset_lib/wikitext103_input_pipeline.py | 172 ++++++++++++++++++
 init2winit/dataset_lib/wikitext2.py           |   3 +-
 .../dataset_lib/wikitext2_input_pipeline.py   |   4 +-
 ...xt2_tokenizer.py => wikitext_tokenizer.py} |   6 +-
 init2winit/model_lib/lstm_lm.py               |   4 -
 init2winit/utils.py                           |   2 +-
 9 files changed, 366 insertions(+), 10 deletions(-)
 create mode 100644 init2winit/dataset_lib/test_wikitext103.py
 create mode 100644 init2winit/dataset_lib/wikitext103.py
 create mode 100644 init2winit/dataset_lib/wikitext103_input_pipeline.py
 rename init2winit/dataset_lib/{wikitext2_tokenizer.py => wikitext_tokenizer.py} (94%)

diff --git a/init2winit/dataset_lib/datasets.py b/init2winit/dataset_lib/datasets.py
index 618e1841..6a856e19 100644
--- a/init2winit/dataset_lib/datasets.py
+++ b/init2winit/dataset_lib/datasets.py
@@ -29,6 +29,7 @@
 from init2winit.dataset_lib import proteins
 from init2winit.dataset_lib import small_image_datasets
 from init2winit.dataset_lib import translate_wmt
+from init2winit.dataset_lib import wikitext103
 from init2winit.dataset_lib import wikitext2
 
 _Dataset = collections.namedtuple(
@@ -100,6 +101,9 @@
     'wikitext2':
         _Dataset(wikitext2.get_wikitext2, wikitext2.DEFAULT_HPARAMS,
                  wikitext2.METADATA, None),
+    'wikitext103':
+        _Dataset(wikitext103.get_wikitext103, wikitext103.DEFAULT_HPARAMS,
+                 wikitext103.METADATA, None),
 }
 
 
diff --git a/init2winit/dataset_lib/test_wikitext103.py b/init2winit/dataset_lib/test_wikitext103.py
new file mode 100644
index 00000000..465fd505
--- /dev/null
+++ b/init2winit/dataset_lib/test_wikitext103.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+# Copyright 2022 The init2winit Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for init2winit.dataset_lib.wikitext103."""
+
+from absl.testing import absltest
+from init2winit.dataset_lib import wikitext103
+from jax import random
+
+
+class TestWikitext103(absltest.TestCase):
+  """Unit tests for wikitext103.py."""
+
+  def test_vocab_size(self):
+    """Test vocab size."""
+    wikitext103_hps = wikitext103.DEFAULT_HPARAMS
+    dataset = wikitext103.get_wikitext103(
+        shuffle_rng=random.PRNGKey(0),
+        batch_size=1,
+        eval_batch_size=1,
+        hps=wikitext103_hps,
+    )
+
+    tokens = set()
+
+    for batch in dataset.eval_train_epoch():
+      inputs = batch['inputs']
+      targets = batch['targets']
+
+      # pylint: disable=g-complex-comprehension
+      inputs_flat = [item for sublist in inputs for item in sublist]
+      targets_flat = [item for sublist in targets for item in sublist]
+
+      inputs_set = set(inputs_flat)
+      targets_set = set(targets_flat)
+
+      tokens = tokens.union(inputs_set, targets_set)
+
+    # Subtract 1 for the padding token.
+    num_tokens = len(tokens) - 1
+
+    self.assertEqual(num_tokens, wikitext103_hps.vocab_size)
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/init2winit/dataset_lib/wikitext103.py b/init2winit/dataset_lib/wikitext103.py
new file mode 100644
index 00000000..824d8916
--- /dev/null
+++ b/init2winit/dataset_lib/wikitext103.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2022 The init2winit Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module containing hyperparameters, metadata and dataset getter for Wikitext-103 dataset."""
+
+import itertools
+
+from init2winit.dataset_lib import data_utils
+from init2winit.dataset_lib import wikitext103_input_pipeline as input_pipeline
+from init2winit.dataset_lib.data_utils import Dataset
+from init2winit.dataset_lib.wikitext2_input_pipeline import PAD_ID
+import jax
+from ml_collections.config_dict import config_dict
+import numpy as np
+
+VOCAB_SIZE = 267735
+
+DEFAULT_HPARAMS = config_dict.ConfigDict(
+    dict(
+        sequence_length=128,
+        max_target_length=128,
+        max_eval_target_length=128,
+        input_shape=(128,),
+        output_shape=(VOCAB_SIZE,),
+        vocab_size=VOCAB_SIZE,
+        train_size=800210,  # Number of sequences.
+    ))
+
+
+METADATA = {
+    'apply_one_hot_in_loss': True,
+    'shift_inputs': True,
+    'causal': True,
+    'pad_token': -1,
+}
+
+
+def add_weights_to_batch(batch, pad_id: int = PAD_ID):
+  """Add weights for the input values so that paddings have 0 weight.
+
+  Args:
+    batch: Batch represented by a dict containing 'inputs' and 'targets'.
+    pad_id: Value for 'inputs' that will have weight 0.
+
+  Returns:
+    The batch with a 'weights' entry added.
+  """
+  batch['weights'] = np.where(batch['inputs'] == pad_id, 0.0, 1.0)
+  return batch
+
+
+def get_wikitext103(shuffle_rng,
+                    batch_size: int,
+                    eval_batch_size: int = None,
+                    hps: config_dict.ConfigDict = None,
+                    ) -> Dataset:
+  """Returns Wikitext-103 Dataset.
+
+  Args:
+    shuffle_rng: jax.random.PRNGKey
+    batch_size: training batch size
+    eval_batch_size: validation batch size
+    hps: Hyperparameters
+
+  Returns:
+    Dataset
+
+  Raises:
+    ValueError: If batch_size is not divisible by jax process count.
+    ValueError: If eval_batch_size is not divisible by jax process count.
+  """
+  process_count = jax.process_count()
+
+  if eval_batch_size is None:
+    eval_batch_size = batch_size
+
+  if batch_size % process_count != 0:
+    raise ValueError(
+        'process_count={} must divide batch_size={}.'.format(
+            process_count, batch_size))
+
+  if eval_batch_size % process_count != 0:
+    raise ValueError(
+        'process_count={} must divide eval_batch_size={}.'.format(
+            process_count, eval_batch_size))
+
+  train_dataset, eval_train_dataset, valid_dataset, test_dataset = input_pipeline.get_wikitext103_dataset(
+      hps,
+      train_batch_size=batch_size,
+      valid_batch_size=eval_batch_size,
+      test_batch_size=eval_batch_size,
+      shuffle_seed=shuffle_rng[0],
+  )
+
+  def train_iterator_fn():
+    for batch in train_dataset:
+      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
+
+  def eval_train_epoch(num_batches=None):
+    for batch in itertools.islice(iter(eval_train_dataset), num_batches):
+      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
+
+  def valid_epoch(num_batches=None):
+    for batch in itertools.islice(iter(valid_dataset), num_batches):
+      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
+
+  def test_epoch(num_batches=None):
+    for batch in itertools.islice(iter(test_dataset), num_batches):
+      yield add_weights_to_batch(data_utils.tf_to_numpy(batch))
+
+  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
+                 test_epoch)
diff --git a/init2winit/dataset_lib/wikitext103_input_pipeline.py b/init2winit/dataset_lib/wikitext103_input_pipeline.py
new file mode 100644
index 00000000..bb0bdea6
--- /dev/null
+++ b/init2winit/dataset_lib/wikitext103_input_pipeline.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2022 The init2winit Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for processing wikitext-103 train, val and test datasets from raw text files to tokenized and batched tf.data.Datasets."""
+
+import os
+
+from init2winit.dataset_lib import wikitext_tokenizer
+from ml_collections.config_dict import config_dict
+import tensorflow as tf
+
+TRAIN_FILENAME = 'train.txt'
+VALID_FILENAME = 'valid.txt'
+TEST_FILENAME = 'test.txt'
+SHUFFLE_BUFFER_SIZE = 1_000_000
+PAD_ID = -1
+
+
+def get_trained_tokenizer(train_dataset: tf.data.Dataset) -> wikitext_tokenizer.Tokenizer:
+  tokenizer = wikitext_tokenizer.Tokenizer()
+  tokenizer.train(train_dataset)
+  return tokenizer
+
+
+def split_input_target(sequence):
+  input_sequence = sequence[:-1]
+  target_sequence = sequence[1:]
+  return {'inputs': input_sequence, 'targets': target_sequence}
+
+
+def batch_with_padding(dataset: tf.data.Dataset,
+                       batch_size,
+                       padded_shapes=None,
+                       padding_id=PAD_ID,
+                       ):
+  """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size.
+
+  Args:
+    dataset: tf.data.Dataset
+    batch_size: batch size of the resulting batched dataset
+    padded_shapes: shapes of the padded batches
+    padding_id: value used to pad the elements in the new batches
+
+  Returns:
+    The batched tf.data.Dataset with padded elements.
+  """
+  batched_dataset = dataset.batch(batch_size, drop_remainder=False)
+
+  # tf.data.Dataset.padded_batch pads the elements within each batch, so we
+  # call it again with batch_size=1 to pad each element of the original batch.
+  padded_batched_dataset = batched_dataset.padded_batch(
+      1, padded_shapes=padded_shapes, padding_values=padding_id)
+
+  # Remove the extra dimension resulting from batch_size=1.
+  padded_batched_dataset = padded_batched_dataset.unbatch()
+
+  return padded_batched_dataset
+
+
+def get_wikitext103_dataset(
+    hps: config_dict.ConfigDict,
+    train_batch_size: int,
+    valid_batch_size: int,
+    test_batch_size: int,
+    shuffle_seed: int,
+) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
+  """Returns the wikitext-103 dataset.
+
+  Args:
+    hps: Dataset hyperparameters.
+    train_batch_size: Batch size for train iterations
+    valid_batch_size: Batch size for validation iterations
+    test_batch_size: Batch size for test iterations
+    shuffle_seed: seed for shuffling dataset sequences
+
+  Returns:
+    train_dataset, eval_train_dataset, valid_dataset, test_dataset
+  """
+  train_path = os.path.join(DATA_DIR, TRAIN_FILENAME)
+  valid_path = os.path.join(DATA_DIR, VALID_FILENAME)
+  test_path = os.path.join(DATA_DIR, TEST_FILENAME)
+
+  # Get TextLineDatasets from the raw files.
+  train_text_dataset = tf.data.TextLineDataset(train_path)
+  valid_text_dataset = tf.data.TextLineDataset(valid_path)
+  test_text_dataset = tf.data.TextLineDataset(test_path)
+
+  # Tokenize the data.
+  tokenizer = get_trained_tokenizer(train_text_dataset)
+
+  train_dataset_tokenized = tokenizer.tokenize(
+      train_text_dataset)
+  valid_dataset_tokenized = tokenizer.tokenize(
+      valid_text_dataset)
+  test_dataset_tokenized = tokenizer.tokenize(
+      test_text_dataset)
+
+  # Divide the data into sequences of length sequence_length + 1, so that each
+  # sequence contains inputs and the corresponding targets.
+  train_dataset_sequences = batch_with_padding(
+      train_dataset_tokenized,
+      hps.sequence_length + 1,
+      padded_shapes=hps.sequence_length + 1,
+  )
+  valid_dataset_sequences = batch_with_padding(
+      valid_dataset_tokenized,
+      hps.sequence_length + 1,
+      padded_shapes=hps.sequence_length + 1,
+  )
+  test_dataset_sequences = batch_with_padding(
+      test_dataset_tokenized,
+      hps.sequence_length + 1,
+      padded_shapes=hps.sequence_length + 1,
+  )
+
+  # Split the sequences into inputs and targets.
+  train_dataset_sequences = train_dataset_sequences.map(split_input_target)
+  valid_dataset_sequences = valid_dataset_sequences.map(split_input_target)
+  test_dataset_sequences = test_dataset_sequences.map(split_input_target)
+
+  # Copy train_dataset_sequences to a non-repeating dataset for eval_train.
+  eval_train_dataset_sequences = train_dataset_sequences
+
+  # Shuffle the train sequences.
+  train_dataset_sequences = train_dataset_sequences.shuffle(
+      SHUFFLE_BUFFER_SIZE, seed=shuffle_seed)
+
+  # Perform batching for training, validation and testing.
+  # Make the training data repeat indefinitely.
+  train_dataset_sequences = train_dataset_sequences.repeat()
+  train_dataset = train_dataset_sequences.batch(
+      train_batch_size,
+      drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
+  # Use padded batches for the eval_train, validation and test datasets since
+  # their sequences do not repeat indefinitely.
+  eval_train_dataset = batch_with_padding(
+      eval_train_dataset_sequences,
+      train_batch_size,
+      padded_shapes={
+          'inputs': (train_batch_size, None),
+          'targets': (train_batch_size, None)
+      }).prefetch(tf.data.experimental.AUTOTUNE)
+
+  valid_dataset = batch_with_padding(
+      valid_dataset_sequences,
+      valid_batch_size,
+      padded_shapes={
+          'inputs': (valid_batch_size, None),
+          'targets': (valid_batch_size, None)
+      }).prefetch(tf.data.experimental.AUTOTUNE)
+
+  test_dataset = batch_with_padding(
+      test_dataset_sequences,
+      test_batch_size,
+      padded_shapes={
+          'inputs': (test_batch_size, None),
+          'targets': (test_batch_size, None)
+      }).prefetch(tf.data.experimental.AUTOTUNE)
+
+  return train_dataset, eval_train_dataset, valid_dataset, test_dataset
diff --git a/init2winit/dataset_lib/wikitext2.py b/init2winit/dataset_lib/wikitext2.py
index db513c93..8fbdc358 100644
--- a/init2winit/dataset_lib/wikitext2.py
+++ b/init2winit/dataset_lib/wikitext2.py
@@ -26,14 +26,13 @@
 import numpy as np
 
 VOCAB_SIZE = 33278
-TEST_BATCH_SIZE = 16
 
 DEFAULT_HPARAMS = config_dict.ConfigDict(
     dict(
         sequence_length=34,
         max_target_length=34,
         max_eval_target_length=34,
-        input_shape=(32,),
+        input_shape=(34,),
         output_shape=(VOCAB_SIZE,),
         vocab_size=VOCAB_SIZE,
         # TODO(kasimbeg) : add vocab path after seperating out tokenizer
diff --git a/init2winit/dataset_lib/wikitext2_input_pipeline.py b/init2winit/dataset_lib/wikitext2_input_pipeline.py
index f56cafb3..3b68860b 100644
--- a/init2winit/dataset_lib/wikitext2_input_pipeline.py
+++ b/init2winit/dataset_lib/wikitext2_input_pipeline.py
@@ -17,7 +17,7 @@
 
 import os
 
-from init2winit.dataset_lib import wikitext2_tokenizer
+from init2winit.dataset_lib import wikitext_tokenizer
 from ml_collections.config_dict import config_dict
 import tensorflow as tf
 
@@ -29,7 +29,7 @@
 
 
 def get_trained_tokenizer(train_dataset: tf.data.Dataset,) -> tf.data.Dataset:
-  tokenizer = wikitext2_tokenizer.Tokenizer()
+  tokenizer = wikitext_tokenizer.Tokenizer()
   tokenizer.train(train_dataset)
   return tokenizer
 
diff --git a/init2winit/dataset_lib/wikitext2_tokenizer.py b/init2winit/dataset_lib/wikitext_tokenizer.py
similarity index 94%
rename from init2winit/dataset_lib/wikitext2_tokenizer.py
rename to init2winit/dataset_lib/wikitext_tokenizer.py
index 105f19e7..8a4ee2ab 100644
--- a/init2winit/dataset_lib/wikitext2_tokenizer.py
+++ b/init2winit/dataset_lib/wikitext_tokenizer.py
@@ -14,6 +14,10 @@
 # limitations under the License.
 
 """Contains Tokenizer class for word level tokenization.
+
+Note that the current tokenization workflow is not yet optimized for time and
+memory.
+
 """
 
 import tensorflow as tf
@@ -68,7 +72,7 @@ def tokenize(self, dataset: tf.data.TextLineDataset) -> tf.data.Dataset:
     idss = []
     for line in dataset:
       ids = []
-      words = line.numpy().split() + [b'']
+      words = line.numpy().split() + [EOS_TOKEN]
       for word in words:
         try:
           ids.append(self.dictionary.word2idx[word])
diff --git a/init2winit/model_lib/lstm_lm.py b/init2winit/model_lib/lstm_lm.py
index a95e5bda..276d16e2 100644
--- a/init2winit/model_lib/lstm_lm.py
+++ b/init2winit/model_lib/lstm_lm.py
@@ -117,10 +117,6 @@ class LSTMModel(base_model.BaseModel):
   def evaluate_batch(self, params, batch_stats, batch):
     """Evaluates metrics on the given batch.
 
-    This method uses the class method apply_on_batch instead of
-    flax_module.apply because the flax_module. The apply_on_batch method
-    handles the 'length' input into the flax_nlp LSTM inner module.
-
     We use the CLU metrics library to evaluate the metrics, and we require
     that each metric_fn in metrics_bundle has the API:
       metric_fn(logits, targets, weights), including the argument names.
diff --git a/init2winit/utils.py b/init2winit/utils.py
index abc3ea4f..983090d2 100644
--- a/init2winit/utils.py
+++ b/init2winit/utils.py
@@ -339,7 +339,7 @@ def tabulate_model(model, hps):
   tabulate_fn = nn.tabulate(model.flax_module, jax.random.PRNGKey(0),
                             console_kwargs={'force_terminal': False,
                                             'force_jupyter': False,
-                                            'width': 120},
+                                            'width': 240},
                             )
   fake_inputs_hps = copy.copy(hps)
   fake_inputs_hps.batch_size = 2
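
Illustrative usage (not part of the patch): a minimal sketch of how the wikitext-103 getter added above might be exercised. It relies only on names introduced in this patch (DEFAULT_HPARAMS, get_wikitext103, eval_train_epoch, the 'weights' entry added by add_weights_to_batch); the batch sizes are arbitrary and must be divisible by jax.process_count(), and it assumes the raw wikitext-103 train/valid/test text files are available where the input pipeline expects them.

    from init2winit.dataset_lib import wikitext103
    from jax import random

    hps = wikitext103.DEFAULT_HPARAMS
    dataset = wikitext103.get_wikitext103(
        shuffle_rng=random.PRNGKey(0),
        batch_size=8,       # illustrative; must be divisible by jax.process_count()
        eval_batch_size=8,
        hps=hps,
    )

    # Each yielded batch is a dict of numpy arrays with 'inputs', 'targets' and
    # the 'weights' added by add_weights_to_batch (0.0 at padded positions).
    batch = next(dataset.eval_train_epoch(num_batches=1))
    print(batch['inputs'].shape, batch['weights'].sum())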