Add simple sweep experiments to compare Transformer experiments on wikitext-103 vs lm1b. #488

Open · wants to merge 1 commit into base: master
4 changes: 4 additions & 0 deletions init2winit/dataset_lib/datasets.py
@@ -29,6 +29,7 @@
from init2winit.dataset_lib import proteins
from init2winit.dataset_lib import small_image_datasets
from init2winit.dataset_lib import translate_wmt
from init2winit.dataset_lib import wikitext103
from init2winit.dataset_lib import wikitext2

_Dataset = collections.namedtuple(
@@ -100,6 +101,9 @@
'wikitext2':
_Dataset(wikitext2.get_wikitext2, wikitext2.DEFAULT_HPARAMS,
wikitext2.METADATA, None),
'wikitext103':
_Dataset(wikitext103.get_wikitext103, wikitext103.DEFAULT_HPARAMS,
wikitext103.METADATA, None),
}


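For context, each registry entry bundles a dataset getter, its default hyperparameters, and metadata. A minimal sketch of what the new 'wikitext103' entry exposes (illustrative only; the values come from the modules added in this PR):

from init2winit.dataset_lib import wikitext103

print(wikitext103.DEFAULT_HPARAMS.vocab_size)  # 267735
print(wikitext103.METADATA['pad_token'])       # -1
# wikitext103.get_wikitext103 is the getter stored alongside these defaults.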
57 changes: 57 additions & 0 deletions init2winit/dataset_lib/test_wikitext103.py
@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2022 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for init2winit.dataset_lib.wikitext103."""

from absl.testing import absltest
from init2winit.dataset_lib import wikitext103
from jax import random


class TestWikitext103(absltest.TestCase):
"""Unit tests for wikitext103.py."""

def test_vocab_size(self):
"""Test vocab size."""
wikitext103_hps = wikitext103.DEFAULT_HPARAMS
dataset = wikitext103.get_wikitext103(
shuffle_rng=random.PRNGKey(0),
batch_size=1,
eval_batch_size=1,
hps=wikitext103_hps,
)

tokens = set()

for batch in dataset.eval_train_epoch():
inputs = batch['inputs']
targets = batch['targets']

# pylint: disable=g-complex-comprehension
inputs_flat = [item for sublist in inputs for item in sublist]
targets_flat = [item for sublist in targets for item in sublist]

inputs_set = set(inputs_flat)
targets_set = set(targets_flat)

tokens = tokens.union(inputs_set, targets_set)

# Subtract 1 for the padding token
num_tokens = len(tokens) - 1

self.assertEqual(num_tokens, wikitext103_hps.vocab_size)

if __name__ == '__main__':
absltest.main()
124 changes: 124 additions & 0 deletions init2winit/dataset_lib/wikitext103.py
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2022 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing hyperparameters, metadata and dataset getter for Wikitext-2 dataset."""

import itertools

from init2winit.dataset_lib import data_utils
from init2winit.dataset_lib import wikitext103_input_pipeline as input_pipeline
from init2winit.dataset_lib.data_utils import Dataset
from init2winit.dataset_lib.wikitext2_input_pipeline import PAD_ID
import jax
from ml_collections.config_dict import config_dict
import numpy as np

VOCAB_SIZE = 267735

DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
sequence_length=128,
max_target_length=128,
max_eval_target_length=128,
input_shape=(128,),
output_shape=(VOCAB_SIZE,),
vocab_size=VOCAB_SIZE,
train_size=800210, # Number of sequences.
))
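# Note (illustrative arithmetic, not part of this diff): 800210 sequences of
# sequence_length + 1 = 129 tokens cover about 800210 * 129 ≈ 103.2M tokens,
# consistent with the ~103M-token WikiText-103 training split.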


METADATA = {
'apply_one_hot_in_loss': True,
'shift_inputs': True,
'causal': True,
'pad_token': -1,
}


def add_weights_to_batch(batch, pad_id: int = PAD_ID):
"""Add weights for the input values so that paddings have 0 weight.

Args:
batch: Batch represented by dict containing 'inputs' and 'targets'.
pad_id: Value for 'inputs' that will have weight 0.

Returns:
The batch with an added 'weights' entry: 0.0 at padded positions, 1.0
elsewhere.
"""
batch['weights'] = np.where(batch['inputs'] == pad_id, 0.0, 1.0)
return batch
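# Example (illustrative, not part of this diff): with pad_id == PAD_ID == -1,
#   add_weights_to_batch({'inputs': np.array([[5, 7, -1, -1]]),
#                         'targets': np.array([[7, 9, -1, -1]])})
# returns the same dict with an added
#   'weights': array([[1., 1., 0., 0.]])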


def get_wikitext103(shuffle_rng,
batch_size: int,
eval_batch_size: int = None,
hps: config_dict.ConfigDict = None,
) -> Dataset:
"""Returns Wikitext-103 Dataset.

Args:
shuffle_rng: jax.random.PRNGKey
batch_size: training batch size
eval_batch_size: validation batch size
hps: Hyper parameters

Returns:
Dataset

Raises:
ValueError: If batch_size is not divisible by jax process count.
ValueError: If eval_batch_size is not divisible by jax process count.
"""
process_count = jax.process_count()

if batch_size % process_count != 0:
raise ValueError(
'process_count={} must divide batch_size={}.'.format(
process_count, batch_size))

if eval_batch_size is None:
eval_batch_size = batch_size

if eval_batch_size % process_count != 0:
raise ValueError(
'process_count={} must divide eval_batch_size={}.'.format(
process_count, eval_batch_size))

train_dataset, eval_train_dataset, valid_dataset, test_dataset = input_pipeline.get_wikitext103_dataset(
hps,
train_batch_size=batch_size,
valid_batch_size=eval_batch_size,
test_batch_size=eval_batch_size,
shuffle_seed=shuffle_rng[0],
)

def train_iterator_fn():
for batch in train_dataset:
yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

def eval_train_epoch(num_batches=None):
for batch in itertools.islice(iter(eval_train_dataset), num_batches):
yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

def valid_epoch(num_batches=None):
for batch in itertools.islice(iter(valid_dataset), num_batches):
yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

def test_epoch(num_batches=None):
for batch in itertools.islice(iter(test_dataset), num_batches):
yield add_weights_to_batch(data_utils.tf_to_numpy(batch))

return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
test_epoch)
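As a usage sketch (assuming the raw WikiText-103 train/valid/test files are available at the pipeline's configured data directory), the getter can be exercised directly; every yielded batch is a dict of numpy arrays with 'inputs', 'targets' and 'weights', where padded positions (PAD_ID == -1) carry weight 0.0:

from init2winit.dataset_lib import wikitext103
from jax import random

hps = wikitext103.DEFAULT_HPARAMS
dataset = wikitext103.get_wikitext103(
    shuffle_rng=random.PRNGKey(0),
    batch_size=8,
    eval_batch_size=8,
    hps=hps)

# Take a single batch from the non-repeating eval copy of the training data.
batch = next(dataset.eval_train_epoch(num_batches=1))
print(batch['inputs'].shape, batch['weights'].shape)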
172 changes: 172 additions & 0 deletions init2winit/dataset_lib/wikitext103_input_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# coding=utf-8
# Copyright 2022 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for processing wikitext-2 train, val and test datasets from raw text files to tokenized and batched tensorflow.data.Datasets."""

import os

from init2winit.dataset_lib import wikitext_tokenizer
from ml_collections.config_dict import config_dict
import tensorflow as tf

TRAIN_FILENAME = 'train.txt'
VALID_FILENAME = 'valid.txt'
TEST_FILENAME = 'test.txt'
SHUFFLE_BUFFER_SIZE = 1_000_000
PAD_ID = -1


def get_trained_tokenizer(train_dataset: tf.data.Dataset) -> wikitext_tokenizer.Tokenizer:
"""Returns a tokenizer trained on the given text dataset."""
tokenizer = wikitext_tokenizer.Tokenizer()
tokenizer.train(train_dataset)
return tokenizer


def split_input_target(sequence):
input_sequence = sequence[:-1]
target_sequence = sequence[1:]
return {'inputs': input_sequence, 'targets': target_sequence}
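# Example (illustrative): split_input_target(tf.constant([10, 11, 12, 13]))
# returns {'inputs': [10, 11, 12], 'targets': [11, 12, 13]}, i.e. the targets
# are the inputs shifted left by one token.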


def batch_with_padding(dataset: tf.data.Dataset,
batch_size,
padded_shapes=None,
padding_id=PAD_ID,
):
"""Batches a tf.data.Dataset and adds padding if len(dataset) not divisible by the batch size.

Args:
dataset: tf.data.Dataset
batch_size: batch size of resulting batched dataset
padded_shapes: shapes of the padded batches
padding_id: value for padding, for elements in new batch

Returns:
The batched dataset; every batch is padded to padded_shapes (in practice
only the final, shorter batch needs padding).
"""
batched_dataset = dataset.batch(batch_size, drop_remainder=False)

# tf.data.Dataset.padded_batch pads elements within a batch, so we call it
# again with batch_size=1 to pad each batch produced above to padded_shapes.
padded_batched_dataset = batched_dataset.padded_batch(
1, padded_shapes=padded_shapes, padding_values=padding_id)

# Remove extra dimension resulting from the batch_size=1.
padded_batched_dataset = padded_batched_dataset.unbatch()

return padded_batched_dataset
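# Example (illustrative): assuming a flat stream of token ids [3, 4, 5, 6, 7]
# and batch_size=4 with padded_shapes=4, batch() yields [3, 4, 5, 6] and [7];
# the padded_batch(1, ...) + unbatch() pair then pads the short final batch to
# [7, PAD_ID, PAD_ID, PAD_ID] so every emitted element has the same shape.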


def get_wikitext103_dataset(
hps: config_dict.ConfigDict,
train_batch_size: int,
valid_batch_size: int,
test_batch_size: int,
shuffle_seed: int,
) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
"""Returns wikitext-103 dataset.

Args:
hps: Dataset hyper parameters.
train_batch_size: Batch size for train iterations
valid_batch_size: Batch size for validation iterations
test_batch_size: Batch size for test iterations
shuffle_seed: seed for shuffling dataset sequences

Returns:
train_dataset, eval_train_dataset, valid_dataset, test_dataset
"""
train_path = os.path.join(DATA_DIR, TRAIN_FILENAME)
valid_path = os.path.join(DATA_DIR, VALID_FILENAME)
test_path = os.path.join(DATA_DIR, TEST_FILENAME)

# Get TextLineDataset from raw files
train_text_dataset = tf.data.TextLineDataset(train_path)
valid_text_dataset = tf.data.TextLineDataset(valid_path)
test_text_dataset = tf.data.TextLineDataset(test_path)

# Tokenize data
tokenizer = get_trained_tokenizer(train_text_dataset)

train_dataset_tokenized = tokenizer.tokenize(
train_text_dataset)
valid_dataset_tokenized = tokenizer.tokenize(
valid_text_dataset)
test_dataset_tokenized = tokenizer.tokenize(
test_text_dataset)

# Divide the data into sequences of length sequence_length + 1 so that each
# sequence contains both the inputs and the corresponding shifted targets.
train_dataset_sequences = batch_with_padding(
train_dataset_tokenized,
hps.sequence_length + 1,
padded_shapes=hps.sequence_length + 1,
)
valid_dataset_sequences = batch_with_padding(
valid_dataset_tokenized,
hps.sequence_length + 1,
padded_shapes=hps.sequence_length + 1,
)
test_dataset_sequences = batch_with_padding(
test_dataset_tokenized,
hps.sequence_length + 1,
padded_shapes=hps.sequence_length + 1,
)

# Split the sequences into inputs and targets.
train_dataset_sequences = train_dataset_sequences.map(split_input_target)
valid_dataset_sequences = valid_dataset_sequences.map(split_input_target)
test_dataset_sequences = test_dataset_sequences.map(split_input_target)

# Keep a non-repeating copy of train_dataset_sequences for eval_train.
eval_train_dataset_sequences = train_dataset_sequences

# Shuffle the train sequences.
train_dataset_sequences = train_dataset_sequences.shuffle(
SHUFFLE_BUFFER_SIZE, seed=shuffle_seed)

# Perform batching for training, validation and testing.
# Make training data repeat indefinitely.
train_dataset_sequences = train_dataset_sequences.repeat()
train_dataset = train_dataset_sequences.batch(
train_batch_size,
drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
# Use padded batches for the eval_train, validation and test datasets since
# their sequences do not repeat indefinitely.
eval_train_dataset = batch_with_padding(
eval_train_dataset_sequences,
train_batch_size,
padded_shapes={
'inputs': (train_batch_size, None),
'targets': (train_batch_size, None)
}).prefetch(tf.data.experimental.AUTOTUNE)

valid_dataset = batch_with_padding(
valid_dataset_sequences,
valid_batch_size,
padded_shapes={
'inputs': (valid_batch_size, None),
'targets': (valid_batch_size, None)
}).prefetch(tf.data.experimental.AUTOTUNE)

test_dataset = batch_with_padding(
test_dataset_sequences,
test_batch_size,
padded_shapes={
'inputs': (test_batch_size, None),
'targets': (test_batch_size, None)
}).prefetch(tf.data.experimental.AUTOTUNE)

return train_dataset, eval_train_dataset, valid_dataset, test_dataset
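To make the sequence-construction steps concrete, here is a small self-contained sketch that runs only the chunking and input/target-splitting stages on a toy stream of token ids (the tokenizer and file loading are skipped; the ids and sequence length are made up for illustration):

import tensorflow as tf

from init2winit.dataset_lib.wikitext103_input_pipeline import batch_with_padding
from init2winit.dataset_lib.wikitext103_input_pipeline import split_input_target

SEQ_LEN = 4  # stands in for hps.sequence_length

# Toy stream of already-tokenized ids: 0, 1, ..., 10.
tokens = tf.data.Dataset.from_tensor_slices(tf.range(11))

# Chunk into sequences of SEQ_LEN + 1 so inputs and targets overlap by one.
sequences = batch_with_padding(
    tokens, SEQ_LEN + 1, padded_shapes=SEQ_LEN + 1)
examples = list(sequences.map(split_input_target).as_numpy_iterator())

print(examples[0])   # inputs [0 1 2 3], targets [1 2 3 4]
print(examples[-1])  # inputs [10 -1 -1 -1], targets [-1 -1 -1 -1]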
3 changes: 1 addition & 2 deletions init2winit/dataset_lib/wikitext2.py
@@ -26,14 +26,13 @@
import numpy as np

VOCAB_SIZE = 33278
TEST_BATCH_SIZE = 16

DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
sequence_length=34,
max_target_length=34,
max_eval_target_length=34,
input_shape=(32,),
input_shape=(34,),
output_shape=(VOCAB_SIZE,),
vocab_size=VOCAB_SIZE,
# TODO(kasimbeg): add vocab path after separating out tokenizer