Commit 4b2cc52 (1 parent: 877097e): 1 changed file with 274 additions and 69 deletions.
@@ -1,78 +1,283 @@
# Taken from llama code and lightly modified by karpathy and then taken from karpathy and not modified *yoink*
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
"""
Download, preprocess and serve the TinyStories dataset as a DataLoader.
Yoinked from https://github.com/karpathy/llama2.c/blob/master/tinystories.py
"""

import argparse
import glob
import json
import os
import random
import struct
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import List

import numpy as np
import requests
import sentencepiece as spm
import torch
import torch.distributed as dist
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm

DATA_CACHE_DIR = "data"

def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

def download():
    """Downloads the TinyStories dataset to DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        os.system(f"tar -xzf {data_filename} -C {data_dir}")
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    with open(shard_filenames[0], "r") as f:
        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")

def train_vocab(vocab_size):
    """
    Trains a custom sentencepiece tokenizer on the TinyStories dataset.
    The custom tokenizer files will be saved in DATA_CACHE_DIR/tok{N} directories,
    where N is the vocab size. This is also where the pretok .bin files will go.
    """
    assert vocab_size > 0, "Vocab size must be positive"

    # output file prefix path for sentencepiece
    prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")

    # how many shards we'll use for vocab training, kept low for efficiency
    num_shards = 10

    # 1) export a large chunk of text as a single text file tiny.txt
    tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
    with open(tiny_file, "w") as of:
        for shard in tqdm(shard_filenames[:num_shards]):
            with open(shard, "r") as f:
                data = json.load(f)
            for example in data:
                text = example["story"]
                text = text.strip()
                of.write(text + "\n")
    print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")

    # 2) train the sentencepiece model
    print("Will now train the vocab...")
    spm.SentencePieceTrainer.train(input=tiny_file,
                                   model_prefix=prefix,
                                   model_type="bpe",
                                   vocab_size=vocab_size,
                                   self_test_sample_size=0,
                                   input_format="text",
                                   character_coverage=1.0,
                                   num_threads=os.cpu_count(),
                                   split_digits=True,
                                   allow_whitespace_only_pieces=True,
                                   byte_fallback=True,
                                   unk_surface=r" \342\201\207 ",
                                   normalization_rule_name="identity")

    # 3) optional cleanup, ask the user if they'd like to delete tiny.txt
    dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
    if dec.lower() == "y":
        os.remove(tiny_file)
        print(f"Deleted {tiny_file}")

    print(f"Trained tokenizer is in {prefix}.model")
    print("Done.")


TOKENIZER_MODEL = "tokenizer.model"  # the llama sentencepiece tokenizer model


class Tokenizer:
    def __init__(self, tokenizer_model=None):
        model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)

    def export(self):
        # get all the tokens (postprocessed) and their scores as floats
        tokens, scores = [], []
        for i in range(self.n_words):
            # decode the token and light postprocessing
            t = self.sp_model.id_to_piece(i)
            s = self.sp_model.get_score(i)
            if i == self.bos_id:
                t = '\n<s>\n'
            elif i == self.eos_id:
                t = '\n</s>\n'
            t = t.replace('▁', ' ')  # sentencepiece uses this character as whitespace
            b = t.encode('utf-8')  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # record the max token length
        max_token_length = max(len(t) for t in tokens)

        # write to a binary file
        # the tokenizer.bin file is the same as .model file, but .bin
        tokenizer_bin = self.model_path.replace('.model', '.bin')
        with open(tokenizer_bin, 'wb') as f:
            f.write(struct.pack("I", max_token_length))
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)

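
# A minimal sketch (not part of the original file): reading back the
# tokenizer.bin layout that export() writes above, assuming the format is
# exactly one native uint32 max_token_length header, then per token a native
# (float32 score, uint32 byte_length) pair followed by the raw utf-8 bytes.
# The function name read_tokenizer_bin is made up for illustration.
def read_tokenizer_bin(path: str):
    tokens, scores = [], []
    with open(path, "rb") as f:
        (max_token_length,) = struct.unpack("I", f.read(struct.calcsize("I")))
        while True:
            header = f.read(struct.calcsize("fI"))
            if not header:
                break  # end of file: all tokens consumed
            score, length = struct.unpack("fI", header)
            tokens.append(f.read(length))
            scores.append(score)
    return max_token_length, tokens, scores
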
def process_shard(args, vocab_size):
    shard_id, shard = args
    tokenizer_model = get_tokenizer_model_path(vocab_size)
    enc = Tokenizer(tokenizer_model)
    with open(shard, "r") as f:
        data = json.load(f)
    all_tokens = []
    for example in tqdm(data, position=shard_id):
        text = example["story"]
        text = text.strip()  # get rid of leading/trailing whitespace
        tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
        all_tokens.extend(tokens)
    # convert to uint16 nparray
    all_tokens = np.array(all_tokens, dtype=np.uint16)
    # calculate the output filename
    if vocab_size == 0:
        # if we're using Llama 2, just save the tokenized file in the same dir
        tokenized_filename = shard.replace(".json", ".bin")
    else:
        # save .bin files into a new tok{N} directory
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        shard_basename = os.path.basename(shard)
        bin_basename = shard_basename.replace(".json", ".bin")
        tokenized_filename = os.path.join(bin_dir, bin_basename)
    # write the bytes
    with open(tokenized_filename, "wb") as f:
        f.write(all_tokens.tobytes())
    # calculate the average sequence length (they are separated by BOS=1)
    avg_seq_len = all_tokens.size / ((all_tokens == 1).sum())
    print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")

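
# A minimal sketch (not part of the original file): recovering per-story
# token id lists from a pretokenized .bin shard, assuming tokens are stored
# as raw uint16 and each story starts with a BOS id of 1 (true for the
# Llama 2 tokenizer; a custom sentencepiece vocab may use a different id).
# The function name iter_story_token_ids is made up for illustration.
def iter_story_token_ids(bin_path: str):
    m = np.memmap(bin_path, dtype=np.uint16, mode="r")
    bos_positions = np.where(m == 1)[0]
    for i, start in enumerate(bos_positions):
        end = bos_positions[i + 1] if i + 1 < len(bos_positions) else len(m)
        yield m[start + 1:end].tolist()  # one story's token ids, BOS stripped
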
def pretokenize(vocab_size):
    # iterate the shards and tokenize all of them one by one
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if vocab_size > 0:
        # .bin files will be saved into tok{N} directory, create it once here
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        os.makedirs(bin_dir, exist_ok=True)

    # process all the shards in a process pool
    fun = partial(process_shard, vocab_size=vocab_size)
    with ProcessPoolExecutor(max_workers=8) as executor:
        executor.map(fun, enumerate(shard_filenames))
    print("Done.")

class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors."""

    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.vocab_source = vocab_source

    def __iter__(self):
        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and rank to create a unique seed for rng
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        if self.vocab_source == "llama2":
            # the .bin files are right along the .json files
            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        elif self.vocab_source == "custom":
            # the .bin files are in the tok{N} directory
            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
        else:
            raise ValueError(f"Unknown vocab_source {self.vocab_source}")
        shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        # train/test split. let's use only shard 0 for test split, rest train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}"
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
                # open the dataset for reading but keep it on disk with memmap
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                num_batches = len(m) // self.max_seq_len
                num_batches -= 1  # drop the last partial batch
                assert num_batches > 0, "this shard is way too small? investigate."
                ixs = list(range(num_batches))
                rng.shuffle(ixs)
                for ix in ixs:
                    start = ix * self.max_seq_len
                    end = start + self.max_seq_len + 1
                    # calling .astype will copy the data into a new numpy array, now in RAM
                    chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                    x = chunk[:-1]
                    y = chunk[1:]
                    yield x, y

# -----------------------------------------------------------------------------
# public interface functions

def get_tokenizer_model_path(vocab_size):
    """
    Returns path to the sentencepiece tokenizer model for a given vocab size.
    vocab_size = 0 designates the default Llama 2 tokenizer; in that case
    None is returned.
    """
    if vocab_size == 0:
        return None
    else:
        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")

class Task:

    @staticmethod
    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
        ds = PretokDataset(**dataset_kwargs)
        dl = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
        )
        for x, y in dl:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            yield x, y

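
# A minimal usage sketch (not part of the original file): how a training
# loop might consume Task.iter_batches; the batch_size, device and dataset
# settings below are illustrative values, not ones the commit prescribes.
#
#   batch_iter = Task.iter_batches(
#       batch_size=32, device="cuda", num_workers=0,
#       split="train", max_seq_len=256,
#       vocab_size=2048, vocab_source="custom",
#   )
#   x, y = next(batch_iter)  # x, y: (32, 256) int64 tensors on the device
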
# -----------------------------------------------------------------------------
# CLI for constructing the dataset

if __name__ == "__main__":
    """
    These stages are designed to be run in order.
    To tokenize data with the Llama 2 tokenizer:
    python tinystories.py download
    python tinystories.py pretokenize
    To tokenize data with a custom tokenizer we train ourselves with sentencepiece, e.g.:
    python tinystories.py download
    python tinystories.py train_vocab --vocab_size=2048
    python tinystories.py pretokenize --vocab_size=2048
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer")
    parser.add_argument("stage", type=str, choices=["download", "pretokenize", "train_vocab"])
    parser.add_argument("--vocab_size", type=int, default=0, help="pretokenization vocab size. 0 = use Llama 2 tokenizer.")
    args = parser.parse_args()

    # export the tokenizer to .bin only when a model path was given; running
    # this unconditionally would crash before any tokenizer model exists
    if args.tokenizer_model:
        t = Tokenizer(args.tokenizer_model)
        t.export()

    # depending on the stage call the appropriate function
    if args.stage == "download":
        download()
    elif args.stage == "train_vocab":
        train_vocab(vocab_size=args.vocab_size)
    elif args.stage == "pretokenize":
        pretokenize(vocab_size=args.vocab_size)
    else:
        raise ValueError(f"Unknown stage {args.stage}")