diff --git a/3.test_cases/neuronx-distributed/README.md b/3.test_cases/neuronx-distributed/README.md new file mode 100644 index 00000000..ff717d82 --- /dev/null +++ b/3.test_cases/neuronx-distributed/README.md @@ -0,0 +1,173 @@ +# NeuronX distributed test cases + + + +MPT are GPT-style models in [llm-foundry](https://github.com/mosaicml/llm-foundry/tree/main) with some special features -- [Flash Attention](https://arxiv.org/abs/2205.14135) for efficiency, [ALiBi](https://arxiv.org/abs/2108.12409) for context length extrapolation, and stability improvements to mitigate loss spikes. + +This project contains: + +* AWS optimized [llm-foundry](https://github.com/mosaicml/llm-foundry/tree/main) container image. +* Slurm scripts for the [c4 dataset](https://huggingface.co/datasets/c4) preparation and multi-node distributed training. + +## 1. Preparation + +This guide assumes that you have the following: + +* A functional Slurm cluster on AWS. +* Docker, [Pyxis](https://github.com/NVIDIA/pyxis) and [Enroot](https://github.com/NVIDIA/enroot) installed. +* An FSx for Lustre filesystem mounted on `/fsx`. + +We recommend that you setup a Slurm cluster using the templates in the architectures [directory](../../1.architectures). Before creating the Slurm cluster, you need to setup the following environment variables: + +```bash +export APPS_PATH=/apps +export ENROOT_IMAGE=$APPS_PATH/llm-foundry.sqsh +export FSX_PATH=/fsx +export DATA_PATH=$FSX_PATH/c4-dataset +export TEST_CASE_PATH=${HOME}/3.MPT # where you copy the test case or set to your test case path +cd $TEST_CASE_PATH +``` + +then follow the detailed instructions [here](../../1.architectures/2.aws-parallelcluster/README.md). + +## 2. Build the container + +Before running training jobs, you need to use an [Enroot](https://github.com/NVIDIA/enroot) container to retrieve and preprocess the input data. Below are the steps you need to follow: + +1. Copy the test case files to your cluster. You will need `0.llm-foundry.Dockerfile`, +2. Build the Docker image with the command below in this directory. + + ```bash + docker build -t llm-foundry -f 0.llm-foundry.Dockerfile . + ``` + +3. Once the Docker image is built, you can check if it is present with `docker images`. You should see an output similar to this one: + + ```bash + REPOSITORY TAG IMAGE ID CREATED SIZE + llm-foundry latest a964fb32cd53 2 weeks ago 23.6GB + ... + ``` + +4. Convert the Docker image to a squash file with the command below. + + ```bash + enroot import -o ${ENROOT_IMAGE} dockerd://llm-foundry:latest + ``` + + The file will be stored in the `/apps` directory (default). The output should look as below. + + ```bash + [INFO] Fetching image + + 36a8c752c28a2db543d2a632a3fc1fcbd5789a6f3d45b9d3a24632420dedcfa8 + + [INFO] Extracting image content... + [INFO] Creating squashfs filesystem... + + Parallel mksquashfs: Using 32 processors + Creating 4.0 filesystem on /apps/llm-foundry.sqsh, block size 131072. + [========================================================================================================================================================================================================================-] 291068/291068 100% + + Exportable Squashfs 4.0 filesystem, gzip compressed, data block size 131072 + uncompressed data, uncompressed metadata, uncompressed fragments, uncompressed xattrs + duplicates are not removed + ... + ``` + +It will take around 5 minutes to convert the container image from Docker to the Enroot format. Once done proceed to the next stage. + +For ease of testing we've included a `Makefile` that automatically builds and imports the latest image. To run this, execute `make` or you can individually specify `make build` to build the Docker image, `make clean` to remove the squash file and `make import` to import the Dockerfile into enroot squash file. + +## 3. Run the processing job + +You need to retrieve input data and preprocess it before running the training job. + +1. Run a preprocessing job by submitting the script `1.c4-preprocess.sbatch` to Slurm. The command will return the Slurm Job ID. You can use `squeue` to consult the status of your jobs. + + ```bash + sbatch 1.c4-preprocess.sbatch + ``` + + It will create the streaming dataset for composer library using C4 dataset in `/fsx/c4-dataset` (default). + +2. You see a new file in your current working directory called `c4-preprocess_XY.out` where `XY` corresponds the Slurm job ID. This is your output file and will capture the `STDOUT` and `STDERR` from your job. You can check how it progresses via the command `tail -f c4-preprocess_XY.out` with the correct job ID instead of `XY`. If running successfully, the job will generate an output similar to the except below. + + ```console + Downloading (…)okenizer_config.json: 100%|██████████| 156/156 [00:00<00:00, 1.09MB/s] + ... + Downloading metadata: 100%|██████████| 2.40M/2.40M [00:01<00:00, 2.05MB/s] + ... + train_small: 32%|███▏ | 31745/100000 [01:51<00:19, 3538.83it/s] + ... + val_small: 100%|██████████| 10000/10000 [00:19<00:00, 514.19it/s] + ``` + + Please be aware that this job downloads the tokenizer on demand (if it's not available under `./EleutherAI/gpt-neox-20b`), after which the tokenizer will be cached under `$HOME/.cache/huggingface`, and the `$HOME` directory is an NFS filesystem shared by the head node. Please consult the [HuggingFace cache management](https://huggingface.co/docs/datasets/cache) document to learn more about fine-grained control of the HuggingFace cache. + +3. After the job completed, check `/fsx/c4-dataset` (default) which will contain a structure similar as below + + ```console + /fsx/c4-dataset/ + ├── train_small + │ ├── index.json + │ ├── shard.00000.mds + │ ├── shard.00001.mds + │ ├── shard.00002.mds + ... + │ ├── shard.00023.mds + │ └── shard.00024.mds + └── val_small + ├── index.json + ├── shard.00000.mds + ├── shard.00001.mds + └── shard.00002.mds + ``` + +Once preprocessing is done, you will run a training job in the next stage. + +## 4. Distributed training of MPT + +Now that the data is preprocessed, we will pretrain a MPT model with [Mosaic Composer](https://github.com/mosaicml/composer). + +1. Run a training job by submitting script `2.train-mpt-manual-distributed.sbatch` to Slurm via `sbatch` as shown below. + + ```bash + sbatch 2.train-mpt-manual-distributed.sbatch + ``` +by default it runs `mpt-7b` model. You can specify model to be trained as: + ```bash + sbatch 2.train-mpt-manual-distributed.sbatch mpt-30b + ``` + +2. When the training job completes successfully, it should produce a log output similar to the below in the `logs/` directory of `$TEST_CASE_PATH`. + +```console +... +0: [batch=1/300000000]: +0: Train time/epoch: 0 +0: Train time/batch: 0 +0: Train time/sample: 0 +0: Train time/batch_in_epoch: 0 +0: Train time/sample_in_epoch: 0 +0: Train time/token: 0 +0: Train time/token_in_epoch: 0 +0: Train memory/allocated_mem: 3.6287 +0: Train memory/active_mem: 3.6287 +0: Train memory/inactive_mem: 2.7844 +0: Train memory/reserved_mem: 20.9650 +0: Train memory/alloc_retries: 0 +0: Train trainer/device_train_microbatch_size: 8 +0: Train loss/train/total: 12.0000 +0: Train metrics/train/LanguageCrossEntropy: 12.0000 +0: Train metrics/train/LanguagePerplexity: 162754.5000 +0: Train time/train: 0.0037 +0: Train time/val: 0.0000 +... +``` + +## 5. Authors / Reviewers + +* [A] Keita Watanabe - mlkeita@ +* [R] Pierre-Yves Aquilanti - pierreya@ +* [R] Verdi March - marcverd@ diff --git a/3.test_cases/neuronx-distributed/mingpt/0.cpu.py b/3.test_cases/neuronx-distributed/mingpt/0.cpu.py new file mode 100644 index 00000000..3345f63d --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/0.cpu.py @@ -0,0 +1,54 @@ +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from torch.nn import functional as F +from mingpt.model import GPT +from mingpt.datasets import SortDataset +from mingpt.configs import TrainConfig +from mingpt.utils import evaluate + +# create train and test dataset +length = 6 +num_digits = 3 +train_dataset = SortDataset('train', length, num_digits) +test_dataset = SortDataset('test', length, num_digits) +train_config = TrainConfig.get_default_config() +train_loader = DataLoader( + train_dataset, + batch_size=train_config.batch_size, +) +test_loader = DataLoader( + test_dataset, + batch_size=train_config.batch_size, +) + +# create a GPT instance +model_config = GPT.get_default_config() +model_config.model_type = 'gpt-nano' +model_config.vocab_size = train_dataset.get_vocab_size() +model_config.block_size = train_dataset.get_block_size() +model = GPT(model_config) +optimizer = model.configure_optimizers(train_config) + +model.train() +pbar = tqdm(train_loader) +for idx, (x, y) in enumerate(pbar): + optimizer.zero_grad() + # forward the model + logits = model(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1 + ) + # backprop and update the parameters + loss.backward() + optimizer.step() + pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}") + +model.eval() +print("Evaluate performance with train_loader") +evaluate(model, train_loader, length, max_batches=50) +print("Evaluate performance with test_loader") +evaluate(model, test_loader, length, max_batches=50) \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/mingpt/1.neuron.py b/3.test_cases/neuronx-distributed/mingpt/1.neuron.py new file mode 100644 index 00000000..2b1b2497 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/1.neuron.py @@ -0,0 +1,66 @@ +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from torch.nn import functional as F +import torch_xla.core.xla_model as xm +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl + +from mingpt.model import GPT +from mingpt.datasets import SortDataset +from mingpt.configs import TrainConfig +from mingpt.utils import evaluate + + +device = 'xla' +# create train and test dataset +length = 6 +num_digits = 3 +train_dataset = SortDataset('train', length, num_digits) +test_dataset = SortDataset('test', length, num_digits) +train_config = TrainConfig.get_default_config() +train_loader = DataLoader( + train_dataset, + batch_size=train_config.batch_size, +) +test_loader = DataLoader( + test_dataset, + batch_size=train_config.batch_size, +) +# We wrap the dataloader with MpDeviceLoader. This dataloader should take +# care of copying the tensors to device +train_loader = pl.MpDeviceLoader(train_loader, device) +test_loader = pl.MpDeviceLoader(test_loader, device) + +# create a GPT instance +model_config = GPT.get_default_config() +model_config.model_type = 'gpt-nano' +model_config.vocab_size = train_dataset.get_vocab_size() +model_config.block_size = train_dataset.get_block_size() +model = GPT(model_config) +model = model.to(device) +optimizer = model.configure_optimizers(train_config) + + +model.train() +pbar = tqdm(train_loader) +for idx, (x, y) in enumerate(pbar): + optimizer.zero_grad() + # forward the model + logits = model(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1 + ) + # backprop and update the parameters + loss.backward() + xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step + pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}") + +model.eval() +print("Evaluate performance with train_loader") +evaluate(model, train_loader, length, max_batches=50) +print("Evaluate performance with test_loader") +evaluate(model, test_loader, length, max_batches=50) \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/mingpt/2.ddp-neuron.py b/3.test_cases/neuronx-distributed/mingpt/2.ddp-neuron.py new file mode 100644 index 00000000..752fdae1 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/2.ddp-neuron.py @@ -0,0 +1,70 @@ +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from torch.nn import functional as F +import torch_xla.core.xla_model as xm +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl + +from mingpt.model import GPT +from mingpt.datasets import SortDataset +from mingpt.configs import TrainConfig +from mingpt.utils import evaluate + +torch.distributed.init_process_group('xla') +device = xm.xla_device() +rank = xm.get_ordinal() +world_size = xm.xrt_world_size() + +print(f'rank: {rank}, world size {world_size}') +# create train and test dataset +length = 6 +num_digits = 3 +train_dataset = SortDataset('train', length, num_digits) +test_dataset = SortDataset('test', length, num_digits) +train_config = TrainConfig.get_default_config() +train_loader = DataLoader( + train_dataset, + batch_size=train_config.batch_size, +) +test_loader = DataLoader( + test_dataset, + batch_size=train_config.batch_size, +) +# We wrap the dataloader with MpDeviceLoader. This dataloader should take +# care of copying the tensors to device +train_loader = pl.MpDeviceLoader(train_loader, device) +test_loader = pl.MpDeviceLoader(test_loader, device) + +# create a GPT instance +model_config = GPT.get_default_config() +model_config.model_type = 'gpt-nano' +model_config.vocab_size = train_dataset.get_vocab_size() +model_config.block_size = train_dataset.get_block_size() +model = GPT(model_config) +model = model.to(device) +optimizer = model.configure_optimizers(train_config) + + +model.train() +pbar = tqdm(train_loader) +for idx, (x, y) in enumerate(pbar): + optimizer.zero_grad() + # forward the model + logits = model(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1 + ) + # backprop and update the parameters + loss.backward() + xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step + pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}") + +model.eval() +print("Evaluate performance with train_loader") +evaluate(model, train_loader, length, max_batches=50) +print("Evaluate performance with test_loader") +evaluate(model, test_loader, length, max_batches=50) \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/mingpt/LICENSE b/3.test_cases/neuronx-distributed/mingpt/LICENSE new file mode 100644 index 00000000..3d899601 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/LICENSE @@ -0,0 +1,7 @@ +The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/3.test_cases/neuronx-distributed/mingpt/README.md b/3.test_cases/neuronx-distributed/mingpt/README.md new file mode 100644 index 00000000..8fa3ea54 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/README.md @@ -0,0 +1,5 @@ +# minGPT test case + +This test case is an educational sample that guide you through how to construct distributed training codes using NeuronX distributed. + + diff --git a/3.test_cases/neuronx-distributed/mingpt/demo.ipynb b/3.test_cases/neuronx-distributed/mingpt/demo.ipynb new file mode 100644 index 00000000..4e74622e --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/demo.ipynb @@ -0,0 +1,331 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A cute little demo showing the simplest usage of minGPT. Configured to run fine on Macbook Air in like a minute." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data.dataloader import DataLoader\n", + "from mingpt.utils import set_seed\n", + "set_seed(3407)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "\n", + "class SortDataset(Dataset):\n", + " \"\"\" \n", + " Dataset for the Sort problem. E.g. for problem length 6:\n", + " Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n", + " Which will feed into the transformer concatenated as:\n", + " input: 0 0 2 1 0 1 0 0 0 1 1\n", + " output: I I I I I 0 0 0 1 1 2\n", + " where I is \"ignore\", as the transformer is reading the input sequence\n", + " \"\"\"\n", + "\n", + " def __init__(self, split, length=6, num_digits=3):\n", + " assert split in {'train', 'test'}\n", + " self.split = split\n", + " self.length = length\n", + " self.num_digits = num_digits\n", + " \n", + " def __len__(self):\n", + " return 10000 # ...\n", + " \n", + " def get_vocab_size(self):\n", + " return self.num_digits\n", + " \n", + " def get_block_size(self):\n", + " # the length of the sequence that will feed into transformer, \n", + " # containing concatenated input and the output, but -1 because\n", + " # the transformer starts making predictions at the last input element\n", + " return self.length * 2 - 1\n", + "\n", + " def __getitem__(self, idx):\n", + " \n", + " # use rejection sampling to generate an input example from the desired split\n", + " while True:\n", + " # generate some random integers\n", + " inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n", + " # half of the time let's try to boost the number of examples that \n", + " # have a large number of repeats, as this is what the model seems to struggle\n", + " # with later in training, and they are kind of rate\n", + " if torch.rand(1).item() < 0.5:\n", + " if inp.unique().nelement() > self.length // 2:\n", + " # too many unqiue digits, re-sample\n", + " continue\n", + " # figure out if this generated example is train or test based on its hash\n", + " h = hash(pickle.dumps(inp.tolist()))\n", + " inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n", + " if inp_split == self.split:\n", + " break # ok\n", + " \n", + " # solve the task: i.e. sort\n", + " sol = torch.sort(inp)[0]\n", + "\n", + " # concatenate the problem specification and the solution\n", + " cat = torch.cat((inp, sol), dim=0)\n", + "\n", + " # the inputs to the transformer will be the offset sequence\n", + " x = cat[:-1].clone()\n", + " y = cat[1:].clone()\n", + " # we only want to predict at output locations, mask out the loss at the input locations\n", + " y[:self.length-1] = -1\n", + " return x, y\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 -1\n", + "0 -1\n", + "1 -1\n", + "0 -1\n", + "0 -1\n", + "0 0\n", + "0 0\n", + "0 0\n", + "0 0\n", + "0 1\n", + "1 1\n" + ] + } + ], + "source": [ + "# print an example instance of the dataset\n", + "train_dataset = SortDataset('train')\n", + "test_dataset = SortDataset('test')\n", + "x, y = train_dataset[0]\n", + "for a, b in zip(x,y):\n", + " print(int(a),int(b))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 0.09M\n" + ] + } + ], + "source": [ + "# create a GPT instance\n", + "from mingpt.model import GPT\n", + "\n", + "model_config = GPT.get_default_config()\n", + "model_config.model_type = 'gpt-nano'\n", + "model_config.vocab_size = train_dataset.get_vocab_size()\n", + "model_config.block_size = train_dataset.get_block_size()\n", + "model = GPT(model_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "running on device cuda\n" + ] + } + ], + "source": [ + "# create a Trainer object\n", + "from mingpt.trainer import Trainer\n", + "\n", + "train_config = Trainer.get_default_config()\n", + "train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n", + "train_config.max_iters = 2000\n", + "train_config.num_workers = 0\n", + "trainer = Trainer(train_config, model, train_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "iter_dt 0.00ms; iter 0: train loss 1.06407\n", + "iter_dt 18.17ms; iter 100: train loss 0.14712\n", + "iter_dt 18.70ms; iter 200: train loss 0.05315\n", + "iter_dt 19.65ms; iter 300: train loss 0.04404\n", + "iter_dt 31.64ms; iter 400: train loss 0.04724\n", + "iter_dt 18.43ms; iter 500: train loss 0.02521\n", + "iter_dt 19.83ms; iter 600: train loss 0.03352\n", + "iter_dt 19.58ms; iter 700: train loss 0.00539\n", + "iter_dt 18.72ms; iter 800: train loss 0.02057\n", + "iter_dt 18.26ms; iter 900: train loss 0.00360\n", + "iter_dt 18.50ms; iter 1000: train loss 0.00788\n", + "iter_dt 20.64ms; iter 1100: train loss 0.01162\n", + "iter_dt 18.63ms; iter 1200: train loss 0.00963\n", + "iter_dt 18.32ms; iter 1300: train loss 0.02066\n", + "iter_dt 18.40ms; iter 1400: train loss 0.01739\n", + "iter_dt 18.37ms; iter 1500: train loss 0.00376\n", + "iter_dt 18.67ms; iter 1600: train loss 0.00133\n", + "iter_dt 18.38ms; iter 1700: train loss 0.00179\n", + "iter_dt 18.66ms; iter 1800: train loss 0.00079\n", + "iter_dt 18.48ms; iter 1900: train loss 0.00042\n" + ] + } + ], + "source": [ + "def batch_end_callback(trainer):\n", + " if trainer.iter_num % 100 == 0:\n", + " print(f\"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n", + "trainer.set_callback('on_batch_end', batch_end_callback)\n", + "\n", + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# now let's perform some evaluation\n", + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train final score: 5000/5000 = 100.00% correct\n", + "test final score: 5000/5000 = 100.00% correct\n" + ] + } + ], + "source": [ + "def eval_split(trainer, split, max_batches):\n", + " dataset = {'train':train_dataset, 'test':test_dataset}[split]\n", + " n = train_dataset.length # naugy direct access shrug\n", + " results = []\n", + " mistakes_printed_already = 0\n", + " loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)\n", + " for b, (x, y) in enumerate(loader):\n", + " x = x.to(trainer.device)\n", + " y = y.to(trainer.device)\n", + " # isolate the input pattern alone\n", + " inp = x[:, :n]\n", + " sol = y[:, -n:]\n", + " # let the model sample the rest of the sequence\n", + " cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n", + " sol_candidate = cat[:, n:] # isolate the filled in sequence\n", + " # compare the predicted sequence to the true sequence\n", + " correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n", + " for i in range(x.size(0)):\n", + " results.append(int(correct[i]))\n", + " if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense\n", + " mistakes_printed_already += 1\n", + " print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n", + " if max_batches is not None and b+1 >= max_batches:\n", + " break\n", + " rt = torch.tensor(results, dtype=torch.float)\n", + " print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(results), 100*rt.mean()))\n", + " return rt.sum()\n", + "\n", + "# run a lot of examples from both train and test through the model and verify the output correctness\n", + "with torch.no_grad():\n", + " train_score = eval_split(trainer, 'train', max_batches=50)\n", + " test_score = eval_split(trainer, 'test', max_batches=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input sequence : [[0, 0, 2, 1, 0, 1]]\n", + "predicted sorted: [[0, 0, 0, 1, 1, 2]]\n", + "gt sort : [0, 0, 0, 1, 1, 2]\n", + "matches : True\n" + ] + } + ], + "source": [ + "# let's run a random given sequence through the model as well\n", + "n = train_dataset.length # naugy direct access shrug\n", + "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n", + "assert inp[0].nelement() == n\n", + "with torch.no_grad():\n", + " cat = model.generate(inp, n, do_sample=False)\n", + "sol = torch.sort(inp[0])[0]\n", + "sol_candidate = cat[:, n:]\n", + "print('input sequence :', inp.tolist())\n", + "print('predicted sorted:', sol_candidate.tolist())\n", + "print('gt sort :', sol.tolist())\n", + "print('matches :', bool((sol == sol_candidate).all()))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.4 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/3.test_cases/neuronx-distributed/mingpt/docs/README.md b/3.test_cases/neuronx-distributed/mingpt/docs/README.md new file mode 100644 index 00000000..9debd177 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/docs/README.md @@ -0,0 +1,147 @@ + +# minGPT + +![mingpt](mingpt.jpg) + +A PyTorch re-implementation of [GPT](https://github.com/openai/gpt-2), both training and inference. minGPT tries to be small, clean, interpretable and educational, as most of the currently available GPT model implementations can a bit sprawling. GPT is not a complicated model and this implementation is appropriately about 300 lines of code (see [mingpt/model.py](mingpt/model.py)). All that's going on is that a sequence of indices feeds into a [Transformer](https://arxiv.org/abs/1706.03762), and a probability distribution over the next index in the sequence comes out. The majority of the complexity is just being clever with batching (both across examples and over sequence length) for efficiency. + +**note (Jan 2023)**: though I may continue to accept and change some details, minGPT is in a semi-archived state. For more recent developments see my rewrite [nanoGPT](https://github.com/karpathy/nanoGPT). Basically, minGPT became referenced across a wide variety of places (notebooks, blogs, courses, books, etc.) which made me less willing to make the bigger changes I wanted to make to move the code forward. I also wanted to change the direction a bit, from a sole focus on education to something that is still simple and hackable but has teeth (reproduces medium-sized industry benchmarks, accepts some tradeoffs to gain runtime efficiency, etc). + +The minGPT library is three files: [mingpt/model.py](mingpt/model.py) contains the actual Transformer model definition, [mingpt/bpe.py](mingpt/bpe.py) contains a mildly refactored Byte Pair Encoder that translates between text and sequences of integers exactly like OpenAI did in GPT, [mingpt/trainer.py](mingpt/trainer.py) is (GPT-independent) PyTorch boilerplate code that trains the model. Then there are a number of demos and projects that use the library in the `projects` folder: + +- `projects/adder` trains a GPT from scratch to add numbers (inspired by the addition section in the GPT-3 paper) +- `projects/chargpt` trains a GPT to be a character-level language model on some input text file +- `demo.ipynb` shows a minimal usage of the `GPT` and `Trainer` in a notebook format on a simple sorting example +- `generate.ipynb` shows how one can load a pretrained GPT2 and generate text given some prompt + +### Library Installation + +If you want to `import mingpt` into your project: + +``` +git clone https://github.com/karpathy/minGPT.git +cd minGPT +pip install -e . +``` + +### Usage + +Here's how you'd instantiate a GPT-2 (124M param version): + +```python +from mingpt.model import GPT +model_config = GPT.get_default_config() +model_config.model_type = 'gpt2' +model_config.vocab_size = 50257 # openai's model vocabulary +model_config.block_size = 1024 # openai's model block_size (i.e. input context length) +model = GPT(model_config) +``` + +And here's how you'd train it: + +```python +# your subclass of torch.utils.data.Dataset that emits example +# torch LongTensor of lengths up to 1024, with integers from [0,50257) +train_dataset = YourDataset() + +from mingpt.trainer import Trainer +train_config = Trainer.get_default_config() +train_config.learning_rate = 5e-4 # many possible options, see the file +train_config.max_iters = 1000 +train_config.batch_size = 32 +trainer = Trainer(train_config, model, train_dataset) +trainer.run() +``` + +See `demo.ipynb` for a more concrete example. + +### Unit tests + +Coverage is not super amazing just yet but: + +``` +python -m unittest discover tests +``` + +### todos + +- add gpt-2 finetuning demo on arbitrary given text file +- add dialog agent demo +- better docs of outcomes for existing projects (adder, chargpt) +- add mixed precision and related training scaling goodies +- distributed training support +- reproduce some benchmarks in projects/, e.g. text8 or other language modeling +- proper logging instead of print statement amateur hour haha +- i probably should have a requirements.txt file... +- it should be possible to load in many other model weights other than just gpt2-\* + +### References + +Code: + +- [openai/gpt-2](https://github.com/openai/gpt-2) has the model definition in TensorFlow, but not the training code +- [openai/image-gpt](https://github.com/openai/image-gpt) has some more modern gpt-3 like modification in its code, good reference as well +- [huggingface/transformers](https://github.com/huggingface/transformers) has a [language-modeling example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling). It is full-featured but as a result also somewhat challenging to trace. E.g. some large functions have as much as 90% unused code behind various branching statements that is unused in the default setting of simple language modeling + +Papers + some implementation notes: + +#### Improving Language Understanding by Generative Pre-Training (GPT-1) + +- Our model largely follows the original transformer work +- We trained a 12-layer decoder-only transformer with masked self-attention heads (768 dimensional states and 12 attention heads). For the position-wise feed-forward networks, we used 3072 dimensional inner states. +- Adam max learning rate of 2.5e-4. (later GPT-3 for this model size uses 6e-4) +- LR decay: increased linearly from zero over the first 2000 updates and annealed to 0 using a cosine schedule +- We train for 100 epochs on minibatches of 64 randomly sampled, contiguous sequences of 512 tokens. +- Since layernorm is used extensively throughout the model, a simple weight initialization of N(0, 0.02) was sufficient +- bytepair encoding (BPE) vocabulary with 40,000 merges +- residual, embedding, and attention dropouts with a rate of 0.1 for regularization. +- modified version of L2 regularization proposed in (37), with w = 0.01 on all non bias or gain weights +- For the activation function, we used the Gaussian Error Linear Unit (GELU). +- We used learned position embeddings instead of the sinusoidal version proposed in the original work +- For finetuning: We add dropout to the classifier with a rate of 0.1. learning rate of 6.25e-5 and a batchsize of 32. 3 epochs. We use a linear learning rate decay schedule with warmup over 0.2% of training. λ was set to 0.5. +- GPT-1 model is 12 layers and d_model 768, ~117M params + +#### Language Models are Unsupervised Multitask Learners (GPT-2) + +- LayerNorm was moved to the input of each sub-block, similar to a pre-activation residual network +- an additional layer normalization was added after the final self-attention block. +- modified initialization which accounts for the accumulation on the residual path with model depth is used. We scale the weights of residual layers at initialization by a factor of 1/√N where N is the number of residual layers. (weird because in their released code i can only find a simple use of the old 0.02... in their release of image-gpt I found it used for c_proj, and even then only for attn, not for mlp. huh. https://github.com/openai/image-gpt/blob/master/src/model.py) +- the vocabulary is expanded to 50,257 +- increase the context size from 512 to 1024 tokens +- larger batchsize of 512 is used +- GPT-2 used 48 layers and d_model 1600 (vs. original 12 layers and d_model 768). ~1.542B params + +#### Language Models are Few-Shot Learners (GPT-3) + +- GPT-3: 96 layers, 96 heads, with d_model of 12,288 (175B parameters). +- GPT-1-like: 12 layers, 12 heads, d_model 768 (125M) +- We use the same model and architecture as GPT-2, including the modified initialization, pre-normalization, and reversible tokenization described therein +- we use alternating dense and locally banded sparse attention patterns in the layers of the transformer, similar to the Sparse Transformer +- we always have the feedforward layer four times the size of the bottleneck layer, dff = 4 ∗ dmodel +- all models use a context window of nctx = 2048 tokens. +- Adam with β1 = 0.9, β2 = 0.95, and eps = 10−8 +- All models use weight decay of 0.1 to provide a small amount of regularization. (NOTE: GPT-1 used 0.01 I believe, see above) +- clip the global norm of the gradient at 1.0 +- Linear LR warmup over the first 375 million tokens. Then use cosine decay for learning rate down to 10% of its value, over 260 billion tokens. +- gradually increase the batch size linearly from a small value (32k tokens) to the full value over the first 4-12 billion tokens of training, depending on the model size. +- full 2048-sized time context window is always used, with a special END OF DOCUMENT token delimiter + +#### Generative Pretraining from Pixels (Image GPT) + +- When working with images, we pick the identity permutation πi = i for 1 ≤ i ≤ n, also known as raster order. +- we create our own 9-bit color palette by clustering (R, G, B) pixel values using k-means with k = 512. +- Our largest model, iGPT-XL, contains L = 60 layers and uses an embedding size of d = 3072 for a total of 6.8B parameters. +- Our next largest model, iGPT-L, is essentially identical to GPT-2 with L = 48 layers, but contains a slightly smaller embedding size of d = 1536 (vs 1600) for a total of 1.4B parameters. +- We use the same model code as GPT-2, except that we initialize weights in the layerdependent fashion as in Sparse Transformer (Child et al., 2019) and zero-initialize all projections producing logits. +- We also train iGPT-M, a 455M parameter model with L = 36 and d = 1024 +- iGPT-S, a 76M parameter model with L = 24 and d = 512 (okay, and how many heads? looks like the Github code claims 8) +- When pre-training iGPT-XL, we use a batch size of 64 and train for 2M iterations, and for all other models we use a batch size of 128 and train for 1M iterations. +- Adam with β1 = 0.9 and β2 = 0.95 +- The learning rate is warmed up for one epoch, and then decays to 0 +- We did not use weight decay because applying a small weight decay of 0.01 did not change representation quality. +- iGPT-S lr 0.003 +- No dropout is used. + +### License + +MIT diff --git a/3.test_cases/neuronx-distributed/mingpt/generate.ipynb b/3.test_cases/neuronx-distributed/mingpt/generate.ipynb new file mode 100644 index 00000000..9bc9d7f7 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/generate.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", + "from mingpt.model import GPT\n", + "from mingpt.utils import set_seed\n", + "from mingpt.bpe import BPETokenizer\n", + "set_seed(3407)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "use_mingpt = True # use minGPT or huggingface/transformers model?\n", + "model_type = 'gpt2-xl'\n", + "device = 'cuda'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 1557.61M\n" + ] + } + ], + "source": [ + "if use_mingpt:\n", + " model = GPT.from_pretrained(model_type)\n", + "else:\n", + " model = GPT2LMHeadModel.from_pretrained(model_type)\n", + " model.config.pad_token_id = model.config.eos_token_id # suppress a warning\n", + "\n", + "# ship model to device and set to eval mode\n", + "model.to(device)\n", + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def generate(prompt='', num_samples=10, steps=20, do_sample=True):\n", + " \n", + " # tokenize the input prompt into integer input sequence\n", + " if use_mingpt:\n", + " tokenizer = BPETokenizer()\n", + " if prompt == '':\n", + " # to create unconditional samples...\n", + " # manually create a tensor with only the special <|endoftext|> token\n", + " # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n", + " x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long)\n", + " else:\n", + " x = tokenizer(prompt).to(device)\n", + " else:\n", + " tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n", + " if prompt == '': \n", + " # to create unconditional samples...\n", + " # huggingface/transformers tokenizer special cases these strings\n", + " prompt = '<|endoftext|>'\n", + " encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n", + " x = encoded_input['input_ids']\n", + " \n", + " # we'll process all desired num_samples in a batch, so expand out the batch dim\n", + " x = x.expand(num_samples, -1)\n", + "\n", + " # forward the model `steps` times to get samples, in a batch\n", + " y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n", + " \n", + " for i in range(num_samples):\n", + " out = tokenizer.decode(y[i].cpu().squeeze())\n", + " print('-'*80)\n", + " print(out)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n", + "--------------------------------------------------------------------------------\n", + "Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n", + "\n" + ] + } + ], + "source": [ + "generate(prompt='Andrej Karpathy, the', num_samples=10, steps=20)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.4 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt.jpg b/3.test_cases/neuronx-distributed/mingpt/mingpt.jpg new file mode 100644 index 00000000..8070bcb8 Binary files /dev/null and b/3.test_cases/neuronx-distributed/mingpt/mingpt.jpg differ diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/__init__.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/bpe.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/bpe.py new file mode 100644 index 00000000..b8468ef9 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/mingpt/bpe.py @@ -0,0 +1,319 @@ +""" +bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into +sequences of integers, where each integer represents small chunks of commonly +occuring characters. This implementation is based on openai's gpt2 encoder.py: +https://github.com/openai/gpt-2/blob/master/src/encoder.py +but was mildly modified because the original implementation is a bit confusing. +I also tried to add as many comments as possible, my own understanding of what's +going on. +""" + +import os +import json +import regex as re +import requests + +import torch + +# ----------------------------------------------------------------------------- + +def bytes_to_unicode(): + """ + Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode + character that represents it visually. Some bytes have their appearance preserved + because they don't cause any trouble. These are defined in list bs. For example: + chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!". + However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these + bytes, into new characters in a range where chr() returns a single nice character. + So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8). + In particular, the space character is 32, which we can see by ord(' '). Instead, + this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'. + So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters + that "look nice", either in their original form, or a funny shifted character + like 'Ā', or 'Ġ', etc. + """ + # the 188 integers that render fine in their original form and need no shifting + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict + # now get the representations of the other 68 integers that do need shifting + # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop + n = 0 + for b in range(2**8): + if b not in bs: + # if this byte is "ugly" then map it to the next available "nice" character + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + d = dict(zip(bs, cs)) + return d + +def get_pairs(word): + """ + Return all bigrams as a set of tuples, of consecutive elements in the iterable word. + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + +class Encoder: + + def __init__(self, encoder, bpe_merges): + # byte encoder/decoder + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + # bpe token encoder/decoder + self.encoder = encoder + self.decoder = {v:k for k,v in self.encoder.items()} + # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # the splitting pattern used for pre-tokenization + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment + """ + ok so what is this regex looking for, exactly? + python re reference: https://docs.python.org/3/library/re.html + - the vertical bars | is OR, so re.findall will chunkate text as the pieces match, from left to right + - '\'s' would split up things like Andrej's -> (Andrej, 's) + - ' ?\p{L}': optional space followed by 1+ unicode code points in the category "letter" + - ' ?\p{N}': optional space followed by 1+ unicode code points in the category "number" + - ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number + - '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) UNLESS they are followed by non-whitespace + so this will consume whitespace characters in a sequence but exclude the last whitespace in + that sequence. that last whitespace has the opportunity to then match the optional ' ?' in + earlier patterns. + - '\s+': 1+ whitespace characters, intended probably to catch a full trailing sequence of whitespaces at end of string + So TLDR: + - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens + - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces + """ + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + self.cache = {} + + def bpe(self, token): + """ + this function uses self.bpe_ranks to iteratively merge all the possible bpe tokens + up the tree. token is a string of one individual 'word' (after regex tokenization) + and after byte encoding, e.g. 'Ġthere'. + """ + # token is a string of one individual 'word', after byte encoding, e.g. 'Ġthere' + + # memoization, for efficiency + if token in self.cache: + return self.cache[token] + + word = tuple(token) # individual characters that make up the token, in a tuple + pairs = get_pairs(word) # get all bigrams + + if not pairs: + return token + + while True: + + # find the next lowest rank bigram that can be merged + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break # no more bigrams are eligible to be merged + first, second = bigram + + # we will now replace all occurences of (first, second) in the list of current + # words into one merged token first_second, in the output list new_words + new_word = [] + i = 0 + while i < len(word): + + # find the next occurence of first in the sequence of current words + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + # if this occurence is also followed by second, then merge them into one + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + + # all occurences of (first, second) have been merged to first_second + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + + # concat all words into a string, and use ' ' as the separator. Note that + # by now all characters have been byte encoded, guaranteeing that ' ' is + # not used in the actual data and is a 'special' delimiter character + word = ' '.join(word) + + # cache the result and return + self.cache[token] = word + return word + + def encode(self, text): + """ string goes in, list of integers comes out """ + bpe_idx = [] + # pre-tokenize the input text into string tokens (words, roughly speaking) + tokens = re.findall(self.pat, text) + # process each token into BPE integers + for token in tokens: + # encode the token as a bytes (b'') object + token_bytes = token.encode('utf-8') + # translate all bytes to their unicode string representation and flatten + token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) + # perform all the applicable bpe merges according to self.bpe_ranks + token_merged = self.bpe(token_translated).split(' ') + # translate all bpe tokens to integers + token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] + # extend our running list of all output integers + bpe_idx.extend(token_ix) + return bpe_idx + + def encode_and_show_work(self, text): + """ debugging function, same as encode but returns all intermediate work """ + bpe_idx = [] + parts = [] + tokens = re.findall(self.pat, text) + for token in tokens: + token_bytes = token.encode('utf-8') + token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) + token_merged = self.bpe(token_translated).split(' ') + token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] + bpe_idx.extend(token_ix) + parts.append({ + 'token': token, + 'token_bytes': token_bytes, + 'token_translated': token_translated, + 'token_merged': token_merged, + 'token_ix': token_ix, + }) + out = { + 'bpe_idx': bpe_idx, # the actual output sequence + 'tokens': tokens, # result of pre-tokenization + 'parts': parts, # intermediates for each token part + } + return out + + def decode(self, bpe_idx): + """ list of integers comes in, string comes out """ + # inverse map the integers to get the tokens + tokens_merged = [self.decoder[token] for token in bpe_idx] + # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes + tokens_flat = ''.join(tokens_merged) + tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat]) + # recover the full utf-8 string + text = tokens_bytes.decode('utf-8', errors='replace') + return text + +def get_file(local_file, remote_file): + """ downloads remote_file to local_file if necessary """ + if not os.path.isfile(local_file): + print(f"downloading {remote_file} to {local_file}") + response = requests.get(remote_file) + open(local_file, "wb").write(response.content) + +def get_encoder(): + """ + Returns an instance of the GPT BPE Encoder/Decoder + and handles caching of "database" files. + """ + home_dir = os.path.expanduser('~') + cache_dir = os.path.join(home_dir, '.cache', 'mingpt') + os.makedirs(cache_dir, exist_ok=True) + + # load encoder.json that has the raw mappings from token -> bpe index + encoder_local_file = os.path.join(cache_dir, 'encoder.json') + encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json' + get_file(encoder_local_file, encoder_remote_file) + with open(encoder_local_file, 'r') as f: + encoder = json.load(f) + assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token + + # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure + # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab + vocab_local_file = os.path.join(cache_dir, 'vocab.bpe') + vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe' + get_file(vocab_local_file, vocab_remote_file) + with open(vocab_local_file, 'r', encoding="utf-8") as f: + bpe_data = f.read() + # light postprocessing: strip the version on first line and the last line is a blank + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] + assert len(bpe_merges) == 50000 # 50,000 merged tokens + + # construct the Encoder object and return + enc = Encoder(encoder, bpe_merges) + return enc + +# ----------------------------------------------------------------------------- + +class BPETokenizer: + """ PyTorch-aware class that wraps the Encoder above """ + + def __init__(self): + self.encoder = get_encoder() + + def __call__(self, text, return_tensors='pt'): + # PyTorch only; here because we want to match huggingface/transformers interface + assert return_tensors == 'pt' + # single string input for now, in the future potentially a list of strings + assert isinstance(text, str) + # encode and create a "batch dimension" of 1 + idx = [self.encoder.encode(text)] + # wrap into PyTorch tensor + out = torch.tensor(idx, dtype=torch.long) + return out + + def decode(self, idx): + # ensure a simple 1D tensor for now + assert idx.ndim == 1 + # decode indices to text + text = self.encoder.decode(idx.tolist()) + return text + + +if __name__ == '__main__': + + # here is an encoding example + text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗" + e = get_encoder() + r = e.encode_and_show_work(text) + + print("Original text is:") + print(text) + print("First the text gets pre-tokenized, broken up into chunks, the outcome is:") + print(r['tokens']) + # ['Hello', '!!', ' I', "'m", ' Andrej', ' Karpathy', '.', ' It', "'s", ' 2022', '.', ' w', '00', 't', ' :', 'D', ' 🤗'] + print("Then we iterate over each chunk and process them in turn...") + for part in r['parts']: + print(part) + # {'token': 'Hello', 'token_bytes': b'Hello', 'token_translated': 'Hello', 'token_merged': ['Hello'], 'token_ix': [15496]} + # {'token': '!!', 'token_bytes': b'!!', 'token_translated': '!!', 'token_merged': ['!!'], 'token_ix': [3228]} + # {'token': ' I', 'token_bytes': b' I', 'token_translated': 'ĠI', 'token_merged': ['ĠI'], 'token_ix': [314]} + # {'token': "'m", 'token_bytes': b"'m", 'token_translated': "'m", 'token_merged': ["'m"], 'token_ix': [1101]} + # {'token': ' Andrej', 'token_bytes': b' Andrej', 'token_translated': 'ĠAndrej', 'token_merged': ['ĠAndre', 'j'], 'token_ix': [10948, 73]} + # {'token': ' Karpathy', 'token_bytes': b' Karpathy', 'token_translated': 'ĠKarpathy', 'token_merged': ['ĠK', 'arp', 'athy'], 'token_ix': [509, 5117, 10036]} + # {'token': '.', 'token_bytes': b'.', 'token_translated': '.', 'token_merged': ['.'], 'token_ix': [13]} + # {'token': ' It', 'token_bytes': b' It', 'token_translated': 'ĠIt', 'token_merged': ['ĠIt'], 'token_ix': [632]} + # {'token': "'s", 'token_bytes': b"'s", 'token_translated': "'s", 'token_merged': ["'s"], 'token_ix': [338]} + # {'token': ' 2022', 'token_bytes': b' 2022', 'token_translated': 'Ġ2022', 'token_merged': ['Ġ2022'], 'token_ix': [33160]} + # {'token': '.', 'token_bytes': b'.', 'token_translated': '.', 'token_merged': ['.'], 'token_ix': [13]} + # {'token': ' w', 'token_bytes': b' w', 'token_translated': 'Ġw', 'token_merged': ['Ġw'], 'token_ix': [266]} + # {'token': '00', 'token_bytes': b'00', 'token_translated': '00', 'token_merged': ['00'], 'token_ix': [405]} + # {'token': 't', 'token_bytes': b't', 'token_translated': 't', 'token_merged': ['t'], 'token_ix': [83]} + # {'token': ' :', 'token_bytes': b' :', 'token_translated': 'Ġ:', 'token_merged': ['Ġ:'], 'token_ix': [1058]} + # {'token': 'D', 'token_bytes': b'D', 'token_translated': 'D', 'token_merged': ['D'], 'token_ix': [35]} + # {'token': ' 🤗', 'token_bytes': b' \xf0\x9f\xa4\x97', 'token_translated': 'ĠðŁ¤Ĺ', 'token_merged': ['ĠðŁ', '¤', 'Ĺ'], 'token_ix': [12520, 97, 245]} + # (refer to the code inside Encoder.encode for what these intermediates are) + print("and the final outcome is concatenating and flattening all the token_ix:") + print(r['bpe_idx']) + # [15496, 3228, 314, 1101, 10948, 73, 509, 5117, 10036, 13, 632, 338, 33160, 13, 266, 405, 83, 1058, 35, 12520, 97, 245] + # this would then become the integer input sequence to the transformer + print("ready to feed into a Transformer!") diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/configs.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/configs.py new file mode 100644 index 00000000..3557b183 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/mingpt/configs.py @@ -0,0 +1,89 @@ +from ast import literal_eval + +class CfgNode: + """ a lightweight configuration class inspired by yacs """ + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __str__(self): + return self._str_helper(0) + + def _str_helper(self, indent): + """ need to have a helper to support nested indentation for pretty printing """ + parts = [] + for k, v in self.__dict__.items(): + if isinstance(v, CfgNode): + parts.append("%s:\n" % k) + parts.append(v._str_helper(indent + 1)) + else: + parts.append("%s: %s\n" % (k, v)) + parts = [' ' * (indent * 4) + p for p in parts] + return "".join(parts) + + def to_dict(self): + """ return a dict representation of the config """ + return { k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items() } + + def merge_from_dict(self, d): + self.__dict__.update(d) + + def merge_from_args(self, args): + """ + update the configuration from a list of strings that is expected + to come from the command line, i.e. sys.argv[1:]. + + The arguments are expected to be in the form of `--arg=value`, and + the arg can use . to denote nested sub-attributes. Example: + + --model.n_layer=10 --trainer.batch_size=32 + """ + for arg in args: + + keyval = arg.split('=') + assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg + key, val = keyval # unpack + + # first translate val into a python object + try: + val = literal_eval(val) + """ + need some explanation here. + - if val is simply a string, literal_eval will throw a ValueError + - if val represents a thing (like an 3, 3.14, [1,2,3], False, None, etc.) it will get created + """ + except ValueError: + pass + + # find the appropriate object to insert the attribute into + assert key[:2] == '--' + key = key[2:] # strip the '--' + keys = key.split('.') + obj = self + for k in keys[:-1]: + obj = getattr(obj, k) + leaf_key = keys[-1] + + # ensure that this attribute exists + assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config" + + # overwrite the attribute + print("command line overwriting config attribute %s with %s" % (key, val)) + setattr(obj, leaf_key, val) + +class TrainConfig: + @staticmethod + def get_default_config(): + C = CfgNode() + # device to train on + C.device = 'auto' + # dataloder parameters + C.num_workers = 0 + # optimizer parameters + C.batch_size = 8 + C.learning_rate = 5e-4 + C.betas = (0.9, 0.95) + C.max_iters = 8000 + C.weight_decay = 0.1 # only applied on matmul weights + C.grad_norm_clip = 1.0 + return C diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/datasets.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/datasets.py new file mode 100644 index 00000000..cbb7bbc9 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/mingpt/datasets.py @@ -0,0 +1,135 @@ +import pickle + +import torch +from torch.utils.data import Dataset + +class AdditionDataset(Dataset): + """ + Creates n-digit addition problems. For example, if n=2, then an example + addition problem would be to add 85 + 50 = 135. This problem would be + represented as the following string for the GPT: + + "8550531" + + This is because: + - we are discarding the + and =, which are not necessary. We just encode the digits + of the input numbers concatenated together. + - the result 135 is encoded backwards to make the addition easier to learn for the + GPT model, because of how the addition algorithm works. + + As one more example, the problem 6 + 39 = 45 would be encoded as: + + "0639054" + + where you will notice that we are padding with zeros to make sure that we always + produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7. + At test time, we will feed in an addition problem by giving the first 2n digits, + and hoping that the GPT model completes the sequence with the next (n+1) digits + correctly. + """ + + def __init__(self, split, ndigit=2): + self.split = split # train/test + + # split up all addition problems into either training data or test data + self.ndigit = ndigit + assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support" + num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers + rng = torch.Generator() + rng.manual_seed(1337) + perm = torch.randperm(num, generator=rng) + num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500 + self.ixes = perm[:num_test] if split == 'test' else perm[num_test:] + + def get_vocab_size(self): + return 10 # digits 0..9 + + def get_block_size(self): + # a,b,a+b, and +1 due to potential carry overflow, + # but then also -1 because very last digit doesn't ever plug back + # as there is no explicit token to predict, it is implied + return 3*self.ndigit + 1 - 1 + + def __len__(self): + return self.ixes.nelement() + + def __getitem__(self, idx): + ndigit = self.ndigit + # given a problem index idx, first recover the associated a + b + idx = self.ixes[idx].item() + nd = 10**ndigit + a = idx // nd + b = idx % nd + # calculate the "label" of the addition problem a + b + c = a + b + # encode the digits of a, b, c into strings + astr = f'%0{ndigit}d' % a + bstr = f'%0{ndigit}d' % b + cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier + render = astr + bstr + cstr + dix = [int(s) for s in render] # convert each character to its token index + # x will be input to GPT and y will be the associated expected outputs + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence + y[:ndigit*2-1] = -1 # we will only train in the output locations. -1 will mask loss to zero + return x, y + +class SortDataset(Dataset): + """ + Dataset for the Sort problem. E.g. for problem length 6: + Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2 + Which will feed into the transformer concatenated as: + input: 0 0 2 1 0 1 0 0 0 1 1 + output: I I I I I 0 0 0 1 1 2 + where I is "ignore", as the transformer is reading the input sequence + """ + + def __init__(self, split, length=8, num_digits=4): + assert split in {'train', 'test'} + self.split = split + self.length = length + self.num_digits = num_digits + + def __len__(self): + return 10000 # ... + + def get_vocab_size(self): + return self.num_digits + + def get_block_size(self): + # the length of the sequence that will feed into transformer, + # containing concatenated input and the output, but -1 because + # the transformer starts making predictions at the last input element + return self.length * 2 - 1 + + def __getitem__(self, idx): + + # use rejection sampling to generate an input example from the desired split + while True: + # generate some random integers + inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long) + # half of the time let's try to boost the number of examples that + # have a large number of repeats, as this is what the model seems to struggle + # with later in training, and they are kind of rate + if torch.rand(1).item() < 0.5: + if inp.unique().nelement() > self.length // 2: + # too many unqiue digits, re-sample + continue + # figure out if this generated example is train or test based on its hash + h = hash(pickle.dumps(inp.tolist())) + inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test + if inp_split == self.split: + break # ok + + # solve the task: i.e. sort + sol = torch.sort(inp)[0] + + # concatenate the problem specification and the solution + cat = torch.cat((inp, sol), dim=0) + + # the inputs to the transformer will be the offset sequence + x = cat[:-1].clone() + y = cat[1:].clone() + # we only want to predict at output locations, mask out the loss at the input locations + y[:self.length-1] = -1 + return x, y diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/model.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/model.py new file mode 100644 index 00000000..4e41e480 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/mingpt/model.py @@ -0,0 +1,305 @@ +""" +Full definition of a GPT Language Model, all of it in this single file. + +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" + +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from mingpt.configs import CfgNode as CN + +# ----------------------------------------------------------------------------- + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + self.n_embd = config.n_embd + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = nn.ModuleDict(dict( + c_fc = nn.Linear(config.n_embd, 4 * config.n_embd), + c_proj = nn.Linear(4 * config.n_embd, config.n_embd), + act = NewGELU(), + dropout = nn.Dropout(config.resid_pdrop), + )) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlpf(self.ln_2(x)) + return x + +class GPT(nn.Module): + """ GPT Language Model """ + + @staticmethod + def get_default_config(): + C = CN() + # either model_type or (n_layer, n_head, n_embd) must be given in the config + C.model_type = 'gpt' + C.n_layer = None + C.n_head = None + C.n_embd = None + # these options must be filled in externally + C.vocab_size = None + C.block_size = None + # dropout hyperparameters + C.embd_pdrop = 0.1 + C.resid_pdrop = 0.1 + C.attn_pdrop = 0.1 + return C + + def __init__(self, config): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.block_size = config.block_size + + type_given = config.model_type is not None + params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None]) + assert type_given ^ params_given # exactly one of these (XOR) + if type_given: + # translate from model_type to detailed configuration + config.merge_from_dict({ + # names follow the huggingface naming conventions + # GPT-1 + 'openai-gpt': dict(n_layer=12, n_head=12, n_embd=768), # 117M params + # GPT-2 configs + 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params + 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params + 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params + 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + # Gophers + 'gopher-44m': dict(n_layer=8, n_head=16, n_embd=512), + # (there are a number more...) + # I made these tiny models up + 'gpt-mini': dict(n_layer=6, n_head=6, n_embd=192), + 'gpt-micro': dict(n_layer=4, n_head=4, n_embd=128), + 'gpt-nano': dict(n_layer=3, n_head=3, n_embd=48), + }[config.model_type]) + + self.transformer = nn.ModuleDict(dict( + wte = nn.Embedding(config.vocab_size, config.n_embd), + wpe = nn.Embedding(config.block_size, config.n_embd), + drop = nn.Dropout(config.embd_pdrop), + h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f = nn.LayerNorm(config.n_embd), + )) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith('c_proj.weight'): + torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) + + # report number of parameters (note we don't count the decoder parameters in lm_head) + n_params = sum(p.numel() for p in self.transformer.parameters()) + print("number of parameters: %.2fM" % (n_params/1e6,)) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + + @classmethod + def from_pretrained(cls, model_type): + """ + Initialize a pretrained GPT model by copying over the weights + from a huggingface/transformers checkpoint. + """ + assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} + from transformers import GPT2LMHeadModel + + # create a from-scratch initialized minGPT model + config = cls.get_default_config() + config.model_type = model_type + config.vocab_size = 50257 # openai's model vocabulary + config.block_size = 1024 # openai's model block_size + model = GPT(config) + sd = model.state_dict() + + # init a huggingface/transformers model + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + + # copy while ensuring all of the parameters are aligned and match in names and shapes + keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these + transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear. + # this means that we have to transpose these weights when we import them + assert len(keys) == len(sd) + for k in keys: + if any(k.endswith(w) for w in transposed): + # special treatment for the Conv1D weights we need to transpose + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + # vanilla copy over the other parameters + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers(self, train_config): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + # random note: because named_modules and named_parameters are recursive + # we will see the same tensors p many many times. but doing it this way + # allows us to know which parent module any tensor p belongs to... + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + return optimizer + + def forward(self, idx, targets=None): + device = idx.device + b, t = idx.size() + assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_head(x) + + return logits + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:] + # forward the model to get the logits for the index in the sequence + logits = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # either sample from the distribution or take the most likely element + if do_sample: + idx_next = torch.multinomial(probs, num_samples=1) + else: + _, idx_next = torch.topk(probs, k=1, dim=-1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/trainer.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/trainer.py new file mode 100644 index 00000000..4ac6d266 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/mingpt/trainer.py @@ -0,0 +1,110 @@ +""" +Simple training loop; Boilerplate that could apply to any arbitrary neural network, +so nothing in this file really has anything to do with GPT specifically. +""" + +import time +from collections import defaultdict + +import torch +from torch.utils.data.dataloader import DataLoader +from mingpt.utils import CfgNode as CN + +class Trainer: + + @staticmethod + def get_default_config(): + C = CN() + # device to train on + C.device = 'auto' + # dataloder parameters + C.num_workers = 0 + # optimizer parameters + C.max_iters = None + C.batch_size = 64 + C.learning_rate = 3e-4 + C.betas = (0.9, 0.95) + C.max_iters = 2000 + C.weight_decay = 0.1 # only applied on matmul weights + C.grad_norm_clip = 1.0 + return C + + def __init__(self, config, model, train_dataset): + self.config = config + self.model = model + self.optimizer = None + self.train_dataset = train_dataset + self.callbacks = defaultdict(list) + + # determine the device we'll train on + if config.device == 'auto': + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = config.device + self.model = self.model.to(self.device) + print("running on device", self.device) + + # variables that will be assigned to trainer class later for logging and etc + self.iter_num = 0 + self.iter_time = 0.0 + self.iter_dt = 0.0 + + def add_callback(self, onevent: str, callback): + self.callbacks[onevent].append(callback) + + def set_callback(self, onevent: str, callback): + self.callbacks[onevent] = [callback] + + def trigger_callbacks(self, onevent: str): + for callback in self.callbacks.get(onevent, []): + callback(self) + + def run(self): + model, config = self.model, self.config + + # setup the optimizer + self.optimizer = model.configure_optimizers(config) + + # setup the dataloader + train_loader = DataLoader( + self.train_dataset, + sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)), + shuffle=False, + pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + model.train() + self.iter_num = 0 + self.iter_time = time.time() + data_iter = iter(train_loader) + while True: + + # fetch the next batch (x, y) and re-init iterator if needed + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(train_loader) + batch = next(data_iter) + batch = [t.to(self.device) for t in batch] + x, y = batch + + # forward the model + logits, self.loss = model(x, y) + + # backprop and update the parameters + model.zero_grad(set_to_none=True) + self.loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) + self.optimizer.step() + + self.trigger_callbacks('on_batch_end') + self.iter_num += 1 + tnow = time.time() + self.iter_dt = tnow - self.iter_time + self.iter_time = tnow + + # termination conditions + if config.max_iters is not None and self.iter_num >= config.max_iters: + break diff --git a/3.test_cases/neuronx-distributed/mingpt/mingpt/utils.py b/3.test_cases/neuronx-distributed/mingpt/mingpt/utils.py new file mode 100644 index 00000000..b6e68a49 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/mingpt/utils.py @@ -0,0 +1,52 @@ + +import os +import sys +import json +import random + +import numpy as np +import torch + +# ----------------------------------------------------------------------------- + +def evaluate(model, loader, n, split="train", max_batches=None): + results = [] + mistakes_printed_already = 0 + for b, (x, y) in enumerate(loader): + # isolate the input pattern alone + inp = x[:, :n] + sol = y[:, -n:] + # let the model sample the rest of the sequence + cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling + sol_candidate = cat[:, n:] # isolate the filled in sequence + # compare the predicted sequence to the true sequence + correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha + for i in range(x.size(0)): + results.append(int(correct[i])) + if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense + mistakes_printed_already += 1 + print("GPT claims that %s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist())) + if max_batches is not None and b+1 >= max_batches: + break + rt = torch.tensor(results, dtype=torch.float) + print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean())) + return rt.sum() + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def setup_logging(config): + """ monotonous bookkeeping """ + work_dir = config.system.work_dir + # create the work directory if it doesn't already exist + os.makedirs(work_dir, exist_ok=True) + # log the args (if any) + with open(os.path.join(work_dir, 'args.txt'), 'w') as f: + f.write(' '.join(sys.argv)) + # log the config itself + with open(os.path.join(work_dir, 'config.json'), 'w') as f: + f.write(json.dumps(config.to_dict(), indent=4)) diff --git a/3.test_cases/neuronx-distributed/mingpt/setup.py b/3.test_cases/neuronx-distributed/mingpt/setup.py new file mode 100644 index 00000000..9a2d64f6 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/setup.py @@ -0,0 +1,12 @@ +from setuptools import setup + +setup(name='minGPT', + version='0.0.1', + author='Andrej Karpathy', + packages=['mingpt'], + description='A PyTorch re-implementation of GPT', + license='MIT', + install_requires=[ + 'torch', + ], +) diff --git a/3.test_cases/neuronx-distributed/mingpt/tests/test_huggingface_import.py b/3.test_cases/neuronx-distributed/mingpt/tests/test_huggingface_import.py new file mode 100644 index 00000000..dab52a82 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/tests/test_huggingface_import.py @@ -0,0 +1,57 @@ +""" +Ensure that we can load huggingface/transformer GPTs into minGPT +""" + +import unittest +import torch +from transformers import GPT2Tokenizer, GPT2LMHeadModel +from mingpt.model import GPT +from mingpt.bpe import BPETokenizer +# ----------------------------------------------------------------------------- + +class TestHuggingFaceImport(unittest.TestCase): + + def test_gpt2(self): + model_type = 'gpt2' + device = 'cuda' if torch.cuda.is_available() else 'cpu' + prompt = "Hello!!!!!!!!!? 🤗, my dog is a little" + + # create a minGPT and a huggingface/transformers model + model = GPT.from_pretrained(model_type) + model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too + + # ship both to device + model.to(device) + model_hf.to(device) + + # set both to eval mode + model.eval() + model_hf.eval() + + # tokenize input prompt + # ... with mingpt + tokenizer = BPETokenizer() + x1 = tokenizer(prompt).to(device) + # ... with huggingface/transformers + tokenizer_hf = GPT2Tokenizer.from_pretrained(model_type) + model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning + encoded_input = tokenizer_hf(prompt, return_tensors='pt').to(device) + x2 = encoded_input['input_ids'] + + # ensure the logits match exactly + logits1, loss = model(x1) + logits2 = model_hf(x2).logits + self.assertTrue(torch.allclose(logits1, logits2)) + + # now draw the argmax samples from each + y1 = model.generate(x1, max_new_tokens=20, do_sample=False)[0] + y2 = model_hf.generate(x2, max_new_tokens=20, do_sample=False)[0] + self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices + + # convert indices to strings + out1 = tokenizer.decode(y1.cpu().squeeze()) + out2 = tokenizer_hf.decode(y2.cpu().squeeze()) + self.assertTrue(out1 == out2) # compare the exact output strings too + +if __name__ == '__main__': + unittest.main() diff --git a/3.test_cases/neuronx-distributed/mingpt/tutorials/ddp_neuron.py b/3.test_cases/neuronx-distributed/mingpt/tutorials/ddp_neuron.py new file mode 100644 index 00000000..19a28c80 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/tutorials/ddp_neuron.py @@ -0,0 +1,104 @@ +import os +import time +import torch + +from torchvision.datasets import mnist +from torch.utils.data import DataLoader +from torchvision.transforms import ToTensor + +# XLA imports +import torch_xla.core.xla_model as xm + +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl +from torch.utils.data.distributed import DistributedSampler + +# Initialize XLA process group for torchrun +import torch_xla.distributed.xla_backend +torch.distributed.init_process_group('xla') + +# Global constants +EPOCHS = 4 +WARMUP_STEPS = 2 +BATCH_SIZE = 32 + +import torch.nn as nn +import torch.nn.functional as F + +# Declare 3-layer MLP for MNIST dataset +class MLP(nn.Module): + def __init__(self, input_size = 28 * 28, output_size = 10, layers = [120, 84]): + super(MLP, self).__init__() + self.fc1 = nn.Linear(input_size, layers[0]) + self.fc2 = nn.Linear(layers[0], layers[1]) + self.fc3 = nn.Linear(layers[1], output_size) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return F.log_softmax(x, dim=1) + +# Load MNIST train dataset +if not xm.is_master_ordinal(): xm.rendezvous('dataset_download') +train_dataset = mnist.MNIST(root='/tmp/MNIST_DATA_train', + train=True, download=True, transform=ToTensor()) +if xm.is_master_ordinal(): xm.rendezvous('dataset_download') + +def main(): + # XLA MP: get world size + world_size = xm.xrt_world_size() + # multi-processing: ensure each worker has same initial weights + torch.manual_seed(0) + + # Move model to device and declare optimizer and loss function + device = 'xla' + model = MLP().to(device) + # For multiprocessing, scale up learning rate + optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size) + loss_fn = torch.nn.NLLLoss() + + # Prepare data loader + train_sampler = None + if world_size > 1: + train_sampler = DistributedSampler(train_dataset, + num_replicas=world_size, + rank=xm.get_ordinal(), + shuffle=True) + train_loader = DataLoader(train_dataset, + batch_size=BATCH_SIZE, + sampler=train_sampler, + shuffle=False if train_sampler else True) + # XLA MP: use MpDeviceLoader from torch_xla.distributed + train_device_loader = pl.MpDeviceLoader(train_loader, device) + + # Run the training loop + print('----------Training ---------------') + model.train() + for epoch in range(EPOCHS): + start = time.time() + for idx, (train_x, train_label) in enumerate(train_device_loader): + optimizer.zero_grad() + train_x = train_x.view(train_x.size(0), -1) + output = model(train_x) + loss = loss_fn(output, train_label) + loss.backward() + xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step + if idx < WARMUP_STEPS: # skip warmup iterations + start = time.time() + + # Compute statistics for the last epoch + interval = idx - WARMUP_STEPS # skip warmup iterations + throughput = interval / (time.time() - start) + print("Train throughput (iter/sec): {}".format(throughput)) + print("Final loss is {:0.4f}".format(loss.detach().to('cpu'))) + + # Save checkpoint for evaluation (xm.save ensures only one process save) + os.makedirs("checkpoints", exist_ok=True) + checkpoint = {'state_dict': model.state_dict()} + xm.save(checkpoint,'checkpoints/checkpoint.pt') + + print('----------End Training ---------------') + +if __name__ == '__main__': + main() diff --git a/3.test_cases/neuronx-distributed/mingpt/tutorials/main.py b/3.test_cases/neuronx-distributed/mingpt/tutorials/main.py new file mode 100644 index 00000000..ae15cc64 --- /dev/null +++ b/3.test_cases/neuronx-distributed/mingpt/tutorials/main.py @@ -0,0 +1,59 @@ +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + self.n_embd = config.n_embd + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + diff --git a/3.test_cases/neuronx-distributed/mingpt/tutorials/model_paralell_neuron.py b/3.test_cases/neuronx-distributed/mingpt/tutorials/model_paralell_neuron.py new file mode 100644 index 00000000..e69de29b diff --git a/3.test_cases/neuronx-distributed/olmo/configs/OLMo-1B.yaml b/3.test_cases/neuronx-distributed/olmo/configs/OLMo-1B.yaml new file mode 100644 index 00000000..c04f0a7f --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/configs/OLMo-1B.yaml @@ -0,0 +1,446 @@ +run_name: OLMo-1B +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmo-small + +model: + d_model: 2048 + n_heads: 16 + n_layers: 16 + mlp_ratio: 8 + weight_tying: true + alibi: true + rope: false + flash_attention: false # not available on AMD + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: sequential + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: mitchell + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right +# +# save_folder: ${path.choose:${oc.env:SCRATCH_DIR,no_exist}/checkpoints,/results}/${oc.env:SLURM_JOB_ID,${run_name}} +# save_overwrite: false +# # Sharded checkpoints (best for restarts) +# save_interval: 1000 +# save_num_checkpoints_to_keep: 9 +# # Unsharded checkpoints (for final storage) +# save_interval_unsharded: 10000 +# save_num_unsharded_checkpoints_to_keep: -1 +# +# load_path: null +# +# max_duration: 739_328 # 3.1T tokens +# global_train_batch_size: 2048 +# device_train_microbatch_size: 8 +# +# precision: amp_bf16 +# +# fsdp: +# wrapping_strategy: null +# precision: mixed +# +# max_grad_norm: 1.0 +# max_grad_norm_ratio: null +# +# speed_monitor: +# window_size: 20 +# +# eval_interval: ${save_interval} +# eval_subset_num_batches: -1 +# device_eval_batch_size: ${device_train_microbatch_size} +# evaluators: +# # lump all the small datasets together (we still get separate metrics). +# - label: v3-small-ppl-validation +# data: +# num_workers: 0 +# drop_last: true +# datasets: +# v3-small-c4_en-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy +# v3-small-dolma_books-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy +# v3-small-dolma_common-crawl-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy +# v3-small-dolma_pes2o-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy +# v3-small-dolma_reddit-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy +# v3-small-dolma_stack-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy +# v3-small-dolma_wiki-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy +# v3-small-ice-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy +# v3-small-m2d2_s2orc-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy +# v3-small-pile-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy +# v3-small-wikitext_103-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy +# +# - label: v2-small-ppl-validation +# data: +# num_workers: 0 +# drop_last: true +# datasets: +# v2-small-4chan-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy +# v2-small-c4_100_domains-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy +# v2-small-c4_en-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy +# v2-small-gab-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy +# v2-small-ice-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy +# v2-small-m2d2_s2orc-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy +# v2-small-m2d2_wiki-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy +# v2-small-manosphere-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy +# v2-small-mc4_en-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy +# v2-small-pile-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy +# v2-small-ptb-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy +# v2-small-twitterAEE-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy +# v2-small-wikitext_103-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy +# +# - label: piqa +# type: downstream +# +# - label: hellaswag +# type: downstream +# +# - label: winogrande +# type: downstream +# +# - label: openbook_qa +# type: downstream +# +# # - label: boolq # requires implemention of the pmi_dc matrix +# # type: downstream +# +# - label: sciq +# type: downstream +# +# - label: arc_easy +# type: downstream +# +# # - label: arc_challenge # requires implemention of the pmi_dc matrix +# # type: downstream +# +# - label: copa +# type: downstream +# +# - label: rte +# type: downstream +# +# - label: commitment_bank +# type: downstream +# +# - label: mrpc +# type: downstream +# +# - label: sst2 +# type: downstream +# +# data: +# pad_direction: right +# num_workers: 0 +# drop_last: true +# pin_memory: true +# prefetch_factor: 16 +# persistent_workers: true +# timeout: 0 +# paths: +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-007-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-011-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-012-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-013-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-014-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-015-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-016-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-017-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-017-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-018-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-018-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-019-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-020-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-020-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-021-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-022-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-023-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-024-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-025-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-025-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-026-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-026-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-027-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-027-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-028-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-029-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-030-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-031-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-032-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-033-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-033-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-034-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-034-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-035-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-035-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-036-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-036-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-037-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-038-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-039-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-039-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-040-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-041-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-042-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-043-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-044-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-045-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-045-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-046-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-047-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-047-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-048-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-049-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-050-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-051-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-052-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-053-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-054-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-055-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-056-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-057-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-058-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-059-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-060-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-061-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-062-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-063-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-064-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-064-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-065-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-065-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-066-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-066-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-067-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-067-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-068-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-068-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-069-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-069-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-070-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-071-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-072-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-073-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-074-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-074-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-075-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-075-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-076-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-076-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-077-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-078-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-078-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-079-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-079-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-080-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-081-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-082-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-083-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-083-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-084-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-085-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-086-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-087-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-088-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-088-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-089-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-089-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-090-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-090-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-091-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-092-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-093-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-094-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-095-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-096-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-096-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-097-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-098-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-099-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-100-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-101-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-102-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-102-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-103-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-104-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-105-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-105-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-106-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-107-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-108-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-109-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-110-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-111-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-112-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-112-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-113-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-114-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-115-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-116-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-117-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-118-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-118-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-119-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-120-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-120-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-121-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-122-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-123-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-124-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-125-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-126-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-126-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-127-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-128-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-129-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-130-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-131-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-132-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-133-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-134-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-135-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-136-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-137-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-138-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-139-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-139-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-140-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-141-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-142-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-143-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-143-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-144-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-145-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-145-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-146-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-147-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-147-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-148-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-149-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-149-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-150-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-151-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-151-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-152-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-152-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-153-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-153-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-154-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-155-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-156-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-156-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-157-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-158-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-158-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-159-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-160-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-160-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-161-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-161-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-162-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-163-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-163-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-164-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-165-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-165-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-166-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-166-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-167-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-167-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-168-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-169-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-170-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-171-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-172-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-173-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-174-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-174-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-175-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-176-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-177-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-178-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-179-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-179-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-180-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-181-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-182-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-183-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-184-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-185-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-185-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-186-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-187-00000.npy \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/configs/OLMo-7B.yaml b/3.test_cases/neuronx-distributed/olmo/configs/OLMo-7B.yaml new file mode 100644 index 00000000..002891bc --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/configs/OLMo-7B.yaml @@ -0,0 +1,648 @@ +run_name: OLMo-7B +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmo-medium + group: OLMo-7B + +model: + d_model: 4096 + n_heads: 32 + n_layers: 32 + mlp_hidden_size: 22016 + weight_tying: false + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: sequential + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: mitchell + +compile: + fullgraph: false + +optimizer: + name: adamw + learning_rate: 3.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: linear_with_warmup + t_warmup: 5000 + alpha_f: 0.1 + grad_clip_warmup_steps: 1000 + grad_clip_warmup_factor: 10.0 + +#tokenizer: +# identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json +# truncate_direction: right +# +#save_folder: runs/${run_name} +#remote_save_folder: null +#save_overwrite: true +## Sharded checkpoints (best for restarts) +#save_interval: 1000 +#save_num_checkpoints_to_keep: -1 +## Unsharded checkpoints (for final storage) +#save_interval_unsharded: null +#save_num_unsharded_checkpoints_to_keep: -1 +# +#load_path: null +# +#max_duration: 2e12T # 2T tokens +#global_train_batch_size: 2048 +#device_train_microbatch_size: 2 +#time_limit: null +# +#precision: amp_bf16 +# +#fsdp: +# wrapping_strategy: by_block +# precision: mixed +# +#max_grad_norm: 1.0 +#max_grad_norm_ratio: null +# +#speed_monitor: +# window_size: 20 +# +#eval_interval: ${save_interval} +#eval_subset_num_batches: -1 +#device_eval_batch_size: ${device_train_microbatch_size} +#evaluators: +# - label: v3-small-ppl-validation +# data: +# num_workers: 0 +# drop_last: true +# datasets: +# v3-small-c4_en-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy +# v3-small-dolma_books-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy +# v3-small-dolma_common-crawl-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy +# v3-small-dolma_pes2o-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy +# v3-small-dolma_reddit-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy +# v3-small-dolma_stack-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy +# v3-small-dolma_wiki-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy +# v3-small-ice-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy +# v3-small-m2d2_s2orc-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy +# v3-small-pile-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy +# v3-small-wikitext_103-validation: +# - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy +# +# - label: v2-small-ppl-validation +# data: +# num_workers: 0 +# drop_last: true +# datasets: +# v2-small-4chan-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy +# v2-small-c4_100_domains-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy +# v2-small-c4_en-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy +# v2-small-gab-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy +# v2-small-ice-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy +# v2-small-m2d2_s2orc-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy +# v2-small-m2d2_wiki-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy +# v2-small-manosphere-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy +# v2-small-mc4_en-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy +# v2-small-pile-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy +# v2-small-ptb-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy +# v2-small-twitterAEE-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy +# v2-small-wikitext_103-validation: +# - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy +# +# ########################## +# # Downstream evaluations # +# ########################## +# - label: piqa +# type: downstream +# +# - label: hellaswag +# type: downstream +# +# - label: winogrande +# type: downstream +# +# - label: openbook_qa +# type: downstream +# +# # - label: boolq # requires implemention of the pmi_dc matrix +# # type: downstream +# +# - label: sciq +# type: downstream +# +# - label: arc_easy +# type: downstream +# +# # - label: arc_challenge # requires implemention of the pmi_dc matrix +# # type: downstream +# +# - label: copa +# type: downstream +# +# - label: rte +# type: downstream +# +# - label: commitment_bank +# type: downstream +# +# - label: mrpc +# type: downstream +# +# - label: sst2 +# type: downstream +# +#data: +# pad_direction: right +# num_workers: 16 +# drop_last: true +# pin_memory: true +# prefetch_factor: 1 +# persistent_workers: true +# timeout: 0 +# paths: +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-000-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-001-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-002-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-003-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-004-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-005-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-005-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-006-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-007-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-007-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-008-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-009-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-009-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-010-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-011-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-011-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-012-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-012-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-013-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-014-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-015-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-015-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-016-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-016-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-017-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-017-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-018-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-018-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-019-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-019-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-020-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-020-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-021-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-021-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-022-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-022-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-023-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-023-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-024-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-024-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-025-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-026-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-026-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-027-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-028-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-029-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-029-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-030-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-030-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-031-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-031-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-032-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-032-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-033-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-033-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-033-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-034-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-034-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-034-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-035-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-035-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-036-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-036-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-037-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-037-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-038-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-038-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-039-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-039-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-040-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-040-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-041-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-041-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-042-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-042-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-042-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-043-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-043-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-043-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-044-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-044-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-044-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-045-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-045-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-046-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-046-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-046-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-046-00003.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-047-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-047-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-048-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-048-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-049-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-049-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-050-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-050-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-051-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-051-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-052-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-052-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-052-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-053-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-053-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-053-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-054-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-054-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-055-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-055-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-055-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-056-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-056-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-056-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-057-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-057-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-057-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-058-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-058-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-059-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-059-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-060-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-060-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-061-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-061-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-062-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-062-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-062-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-063-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-063-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-063-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-064-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-064-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-064-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-065-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-065-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-065-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-066-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-066-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-067-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-067-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-068-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-068-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-069-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-069-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-070-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-070-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-071-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-071-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-072-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-072-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-073-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-073-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-074-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-074-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-075-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-075-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-076-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-076-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-077-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-077-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-078-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-078-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-079-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-079-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-080-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-080-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-081-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-081-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-082-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-082-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-083-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-083-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-084-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-084-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-085-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-085-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-086-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-086-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-087-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-087-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-088-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-088-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-089-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-089-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-089-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-090-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-090-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-091-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-091-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-091-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-092-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-092-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-093-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-093-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-093-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-094-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-094-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-094-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-095-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-095-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-096-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-096-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-097-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-097-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-097-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-098-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-098-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-099-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-099-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-100-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-100-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-100-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-101-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-101-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-102-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-102-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-103-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-103-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-104-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-104-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-105-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-105-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-106-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-106-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-106-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-107-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-107-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-108-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-108-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-109-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-109-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-109-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-110-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-110-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-110-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-111-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-111-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-112-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-112-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-113-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-113-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-114-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-114-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-114-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-115-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-115-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-116-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-116-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-117-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-117-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-118-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-118-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-119-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-119-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-120-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-120-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-120-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-121-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-121-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-122-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-122-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-122-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-123-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-123-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-123-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-124-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-124-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-125-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-125-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-126-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-126-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-127-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-127-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-127-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-128-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-128-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-129-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-129-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-129-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-130-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-130-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-131-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-131-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-132-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-132-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-133-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-133-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-133-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-134-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-134-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-134-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-135-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-135-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-135-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-136-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-136-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-137-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-137-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-137-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-138-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-138-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-139-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-139-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-140-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-140-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-141-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-141-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-141-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-142-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-142-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-142-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-143-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-143-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-144-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-144-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-144-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-145-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-145-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-145-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-146-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-146-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-146-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-147-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-147-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-147-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-148-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-148-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-149-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-149-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-149-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-150-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-150-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-150-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-150-00003.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-151-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-151-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-152-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-152-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-153-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-153-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-154-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-154-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-155-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-155-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-155-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-156-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-156-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-157-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-157-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-157-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-158-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-158-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-159-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-159-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-160-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-160-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-161-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-161-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-161-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-162-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-162-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-163-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-163-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-164-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-164-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-165-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-165-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-165-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-166-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-166-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-166-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-167-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-167-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-167-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-168-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-168-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-169-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-169-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-170-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-170-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-171-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-171-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-172-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-172-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-173-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-173-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-173-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-174-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-174-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-174-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-175-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-175-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-175-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-176-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-176-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-176-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-177-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-177-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-178-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-178-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-179-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-179-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-180-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-180-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-181-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-181-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-182-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-182-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-182-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-183-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-183-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-183-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-184-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-184-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-185-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-185-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-185-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-186-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-186-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-186-00002.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-187-00000.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-187-00001.npy +# - https://olmo-data.org/preprocessed/olmo-mix/v1_5-sample/gpt-neox-20b-pii-special/part-187-00002.npy \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/aliases.py b/3.test_cases/neuronx-distributed/olmo/olmo/aliases.py new file mode 100644 index 00000000..f9f9b1a3 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/aliases.py @@ -0,0 +1,7 @@ +from os import PathLike +from typing import Union + +__all__ = ["PathOrStr"] + + +PathOrStr = Union[str, PathLike] \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/beam_search.py b/3.test_cases/neuronx-distributed/olmo/olmo/beam_search.py new file mode 100644 index 00000000..fdcaee31 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/beam_search.py @@ -0,0 +1,1078 @@ +""" +This is a self-contained and flexible beam search implementation adapted from +AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py +""" + +import copy +import warnings +from abc import abstractmethod +from inspect import signature +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast + +import torch + +__all__ = [ + "Sampler", + "DeterministicSampler", + "MultinomialSampler", + "TopKSampler", + "TopPSampler", + "GumbelSampler", + "FinalSequenceScorer", + "SequenceLogProbabilityScorer", + "LengthNormalizedSequenceLogProbabilityScorer", + "Constraint", + "RepeatedNGramBlockingConstraint", + "BeamSearch", +] + +StateType = Dict[str, torch.Tensor] +StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]] +StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]] + +StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep) +""" +The type of step function that can be passed to [`BeamSearch.search`](#search). + +This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep) +or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep). +""" + +ConstraintStateType = List[List[Dict[str, Any]]] + + +class Sampler: + """ + An abstract class that can be used to sample candidates (either nodes or beams) + within `BeamSearch`. + + A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`. + + `init_state()` takes three arguments: + + - a tensor of starting log probs with shape `(batch_size,, num_classes)`, + - the batch size, an int, + - and the number of classes, also an int. + + It returns a state dictionary with any state tensors needed for subsequent + calls to `sample_nodes()` and `sample_beams()`. + + By default this method just returns an empty dictionary. + + Both `sample_nodes()` and `sample_beams()` should take three arguments: + + - tensor of normalized log probabilities with shape `(batch_size, num_examples)`, + - an integer representing the number of samples to take for each example in the batch, + - and a state dictionary which could contain any tensors needed for the `Sampler` to keep + track of state. + + For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`, + `num_examples = beam_size * per_node_beam_size`. + + The return value should be a tuple containing: + + - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`, + - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`, + - and the updated state dictionary. + + A default implementation of `sample_beams` is provided, which just deterministically + picks the `k` examples with highest log probability. + """ + + def init_state( + self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int + ) -> StateType: + del start_class_log_probabilities, batch_size, num_classes + return {} + + @abstractmethod + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + raise NotImplementedError + + def sample_beams( + self, log_probs: torch.Tensor, beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + del state + selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1) + return selected_log_probs, selected_indices, {} + + +class DeterministicSampler(Sampler): + """ + A `Sampler` that just deterministically returns the `k` nodes or beams with highest + log probability. + """ + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + del state + selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1) + return selected_log_probs, selected_indices, {} + + +class MultinomialSampler(Sampler): + """ + A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled + in the default, non-deterministic way. + + :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + :param with_replacement: Whether to sample with replacement. + + """ + + def __init__( + self, + temperature: float = 1.0, + with_replacement: bool = False, + ) -> None: + self.temperature = temperature + self.with_replacement = with_replacement + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + if self.temperature != 1.0: + _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1) + else: + _probabilities = log_probs.exp() + + selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement) + + return torch.gather(log_probs, 1, selected_indices), selected_indices, state + + +class TopKSampler(Sampler): + """ + A `Sampler` which redistributes the probability mass function for nodes among the + top `k` choices, then samples from that subset after re-normalizing the probabilities. + + Beams are sampled in the default, deterministic way. + + :param k: The number of top choices to be selected from. + :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices. + """ + + def __init__( + self, + k: int = 1, + temperature: float = 1.0, + with_replacement: bool = False, + ): + self.k = k + self.temperature = temperature or 1.0 + self.with_replacement = with_replacement + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + if not per_node_beam_size <= self.k <= log_probs.size()[1]: + raise ValueError( + "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size" + ) + + # shape (both): (batch_size, k) + top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1) + + # Apply temperature if necessary. + # shape: (batch_size, k) + if self.temperature != 1.0: + top_k_log_probs = top_k_log_probs / self.temperature + + # Re-normalize the subset. + # shape: (batch_size, k) + normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1) + + # Sample from the re-normalized subset. + # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`. + # shape: (batch_size, per_node_beam_size) + sampled_indices = torch.multinomial( + normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement + ) + + # Convert `sampled_indices` back to indices in the original `log_probs` tensor. + # shape: (batch_size, per_node_beam_size) + indices = top_k_indices.gather(-1, sampled_indices) + + return log_probs.gather(1, indices), indices, state + + +class TopPSampler(Sampler): + """ + A `Sampler` which redistributes the probability mass function for nodes among + the top choices with a cumulative probability of at least `p`, then samples from that subset + after re-normalizing the probabilities. + + Beams are sampled in the default, deterministic way. + + :param p: + The cumulative probability cutoff threshold. A higher value of `p` will result in more possible + examples to sample from. If `with_replacement` is `False` and the number of possible samples is + insufficient to sample without replacement from when calling `sample_nodes`, then the top + `per_node_beam_size` examples will be chosen. + :param temperature: + A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + :param with_replacement: + If set to `True`, samples will be selected with replacement from the top choices. + + """ + + def __init__( + self, + p: float = 0.9, + temperature: float = 1.0, + with_replacement: bool = False, + ): + if p < 0.0 or p > 1.0: + raise ValueError("p must be a positive float no greater than 1.0") + self.p = p + self.temperature = temperature or 1.0 + self.with_replacement = with_replacement + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + if not per_node_beam_size <= log_probs.size()[1]: + raise ValueError("per_node_beam_size cannot be greater than vocabulary size") + + # First apply temperature coefficient: + if self.temperature != 1.0: + _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1) + else: + _log_probs = log_probs + + # Sort the probabilities in descending order to then find cumulative sum + log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True) + + # shape: (batch_size, num_classes) + probabilities_descending = log_probs_descending.exp() + probabilities_summed = torch.cumsum(probabilities_descending, dim=-1) + + # Create a mask for filtering out probabilities that don't make the top `p`. + # shape: (batch_size, num_classes) + exclusion_mask = probabilities_summed >= self.p + + # We want to include the first index where probabilities_summed >= p, so we shift over one. + exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone() + exclusion_mask[..., 0] = False + + # Make sure there's at least `per_node_beam_size` options to be selected. + if not self.with_replacement: + exclusion_mask[..., :per_node_beam_size] = False + + log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min + + # Now re-normalized the included log probs. + # shape: (batch_size, num_classes) + filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1) + + # Sample from the re-normalized subset. + # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`. + # shape: (batch_size, per_node_beam_size) + sampled_indices = torch.multinomial( + filtered_probabilities, per_node_beam_size, replacement=self.with_replacement + ) + + # Convert `sampled_indices` back to indices in the original `log_probs` tensor. + # shape: (batch_size, per_node_beam_size) + selected_indices = sorting_indices.gather(-1, sampled_indices) + + # Return (selected log probabilities, selected classes) + # shape: (len(log_probs),1) , (len(log_probs), 1) + return torch.gather(log_probs, 1, selected_indices), selected_indices, state + + +class GumbelSampler(Sampler): + """ + A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See + [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling + Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010] + (https://api.semanticscholar.org/CorpusID:76662039). + + :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + """ + + def __init__(self, temperature: float = 1.0): + self.temperature = temperature + + def init_state( + self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int + ) -> StateType: + # shape: (batch_size, num_classes) + zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes)) + + # shape: (batch_size, num_classes) + G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros) + + return {"G_phi_S": G_phi_S} + + def sample_nodes( + self, + log_probs: torch.Tensor, + per_node_beam_size: int, + state: StateType, + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + # First apply temperature coefficient: + # shape: (batch_size * beam_size, num_classes) + if self.temperature != 1.0: + _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1) + else: + _log_probs = log_probs + + # shape: (group_size,) + phi_S = state["phi_S"] + + # shape: (group_size, num_classes) + phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs) + + # shape: (group_size, num_classes) + phi_S_new = phi_S + _log_probs + + # shape: (group_size, 1) + G_phi_S = state["G_phi_S"].unsqueeze(-1) + + # shape: (group_size, num_classes) + G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S) + + # Replace NaNs with very negative number. + # shape: (group_size, num_classes) + # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min + + # shape (both): (group_size, per_node_beam_size) + top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1) + + # shape: (group_size, per_node_beam_size) + top_log_probs = log_probs.gather(1, top_indices) + + return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new} + + def sample_beams( + self, + log_probs: torch.Tensor, + beam_size: int, + state: StateType, + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + """ + Returns the beams with the highest perturbed log probabilities. + """ + # shape (log_probs): (batch_size, beam_size * per_node_beam_size) + + batch_size = log_probs.size()[0] + + # shape: (batch_size * beam_size, per_node_beam_size) + G_phi_S = state["G_phi_S"] + + # shape: (batch_size, beam_size * per_node_beam_size) + G_phi_S = G_phi_S.reshape_as(log_probs) + + # shape (both): (batch_size, beam_size) + G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1) + + # shape: (batch_size, beam_size) + selected_log_probs = log_probs.gather(1, selected_indices) + + # Now sort the selected beams by their true log prob. + # shape (all): (batch_size, beam_size) + selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True) + selected_indices = selected_indices.gather(1, sort_indices) + G_phi_S_new = G_phi_S_new.gather(1, sort_indices) + + # shape: (batch_size * beam_size,) + G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size) + + # shape: (batch_size * beam_size,) + phi_S = selected_log_probs.reshape(batch_size * beam_size) + + return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S} + + def gumbel(self, phi) -> torch.Tensor: + """ + Sample `Gumbel(phi)`. + + `phi` should have shape `(batch_size, num_classes)`. + """ + return -torch.log(-torch.log(torch.rand_like(phi))) + phi + + def gumbel_with_max(self, phi, T) -> torch.Tensor: + """ + Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`. + + `phi` should have shape `(batch_size, num_classes)` and `T` should have + shape `(batch_size, 1)`. + """ + # Shape: (batch_size, num_classes) + G_phi = self.gumbel(phi) + + # Now we find the maximum from these samples. + # Shape: (batch_size, ) + Z, _ = G_phi.max(dim=-1) + + # Shape: (batch_size, num_classes) + v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1))) + + # Shape: (batch_size, num_classes) + return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs())) + + +class FinalSequenceScorer: + """ + An abstract class that can be used to score the final generated sequences found + by beam search. Given the predicted sequences and the corresponding log probabilities of + those sequences, the class calculates and returns the final score of the sequences. + + The default implementation scores the sequences using the sum of the log probabilities of + the sequence, which is passed as input. + """ + + @abstractmethod + def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: + """ + Score the final predictions found by beam search. + Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`. + + :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`. + :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum + of the log probabilities per token, with shape `(batch_size, beam_size)`. + :param end_index: The index of the end symbol. + + """ + raise NotImplementedError + + +class SequenceLogProbabilityScorer(FinalSequenceScorer): + """ + A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities + across the sequence's tokens. + """ + + def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: + del predictions, end_index + # The sum of the sequence log probabilities is the input parameter, so just + # return it. + return log_probabilities + + +class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer): + """ + A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the + tokens in the sequence. It optionally includes a length penalty which promotes + or demotes sequences based on their lengths. The final score for a sequence will + be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length + here includes the end token. + + :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used. + A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences. + """ + + def __init__(self, length_penalty: float = 1.0): + super().__init__() + self.length_penalty = length_penalty + + def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: + # shape: (batch_size, beam_size) + lengths = (predictions != end_index).long().sum(dim=2) + + # If the sequence ended during beam search, the `log_probabilities` will include + # the transition to the end token. Therefore, in such situations, `lengths` is + # actually off by 1. This corrects for that. + # shape: (batch_size, beam_size) + is_end_token = predictions[:, :, -1] == end_index + lengths += is_end_token.long() + + # shape: (batch_size, beam_size) + average_log_probs = log_probabilities / (lengths**self.length_penalty) + return average_log_probs + + +class Constraint: + """ + An abstract class that can be used to enforce constraints on the output predictions + by manipulating the class log probabilities during beam search. + + A `Constraint` just has three methods that need to be implemented by subclasses: + `init_state()`, `apply()` and `_update_state()`. + + `init_state()` takes one argument: + + - the batch size, an int + + It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent + calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`. + Each inner list should be of length 1. + + `apply()` takes two arguments: + + - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size` + and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1. + - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the + log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`. + + The `apply()` method should return new `class_log_probabilities` that enforce the constraint + for this step of beam search. For instance, it may prevent a specific class from being selected by setting + the corresponding log probability to a negligible value such as `float("-inf")` or + `torch.finfo(class_log_probabilities.dtype).min`. + + `_update_state()` takes two arguments: + + - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the + copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be + directly edited in-place without affecting the others. + - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last + step of beam search. + + The `_update_state()` function should return a new constraint state, a nested list of dictionaries of + length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`. + + """ + + @abstractmethod + def init_state( + self, + batch_size: int, + ) -> ConstraintStateType: + raise NotImplementedError + + @abstractmethod + def apply( + self, + state: ConstraintStateType, + class_log_probabilities: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def _copy_state( + state: ConstraintStateType, + batch_size: int, + beam_size: int, + last_backpointer: Optional[torch.Tensor] = None, + ) -> ConstraintStateType: + """ + Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this + is not appropriate for your constraint, you will need to implement the copying yourself. + """ + new_state = [] + for i in range(batch_size): + batch_state = [] + for j in range(beam_size): + if last_backpointer is None: + # This is the first prediction, so the backpointer is 0 + backpointer = 0 + else: + backpointer = last_backpointer[i, j].item() + batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore + new_state.append(batch_state) + return new_state + + def update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + last_backpointer: Optional[torch.Tensor] = None, + ) -> ConstraintStateType: + batch_size, beam_size = last_prediction.size() + new_state = self._copy_state(state, batch_size, beam_size, last_backpointer) + return self._update_state(new_state, last_prediction) + + @abstractmethod + def _update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + ) -> ConstraintStateType: + raise NotImplementedError + + +class RepeatedNGramBlockingConstraint(Constraint): + def __init__(self, ngram_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.ngram_size = ngram_size + + def init_state( + self, + batch_size: int, + ) -> ConstraintStateType: + return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)] + + def apply( + self, + state: ConstraintStateType, + class_log_probabilities: torch.Tensor, + ) -> torch.Tensor: + for i, batch in enumerate(state): + for j, beam in enumerate(batch): + current_prefix = tuple(beam["current_prefix"]) + seen_ngrams = beam["seen_ngrams"] + try: + disallowed_indices = seen_ngrams[current_prefix] + class_log_probabilities[i, j, disallowed_indices] = torch.finfo( + class_log_probabilities.dtype + ).min + except KeyError: + # We have not seen this prefix before, so there is no index + # that needs to be blocked + pass + return class_log_probabilities + + def _update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + ) -> ConstraintStateType: + for i, batch in enumerate(state): + for j, beam in enumerate(batch): + prediction = last_prediction[i, j].item() + prefix = beam["current_prefix"] + seen_ngrams = beam["seen_ngrams"] + + if len(prefix) == self.ngram_size - 1: + # This is a new ngram that we have to remember + if tuple(prefix) not in seen_ngrams: + seen_ngrams[tuple(prefix)] = [] + seen_ngrams[tuple(prefix)].append(prediction) + + # Create the new prefix, removing the oldest index if the prefix + # is too long + prefix.append(prediction) + if len(prefix) == self.ngram_size: + prefix.pop(0) + return state + + +class BeamSearch: + """ + Implements the beam search algorithm for decoding the most likely sequences. + + :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID. + + :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length + of the predicted sequences. + + :param beam_size: The width of the beam used. + + :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search. + If not given, this just defaults to `beam_size`. Setting this parameter + to a number smaller than `beam_size` may give better results, as it can introduce + more diversity into the search. See + [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017] + (https://api.semanticscholar.org/CorpusID:2229477). + + :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams. + If not specified, `DeterministicSampler` will be used, which just takes the + `per_node_beam_size` most likely nodes and the `beam_size` most likely beams. + + Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you + [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039). + + :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of + the predicted sequences. This does not include the start or end tokens. If `None`, + no minimum is enforced. + + :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences. + The output from this module is what is returned by the `search` method. If not + specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences + by the sum of the token log probabilities. + + :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not + provided, no constraints will be enforced. + + """ + + def __init__( + self, + end_index: int, + *, + max_steps: int = 50, + beam_size: int = 10, + per_node_beam_size: Optional[int] = None, + sampler: Optional[Sampler] = None, + min_steps: Optional[int] = None, + final_sequence_scorer: Optional[FinalSequenceScorer] = None, + constraints: Optional[List[Constraint]] = None, + ) -> None: + if not max_steps > 0: + raise ValueError("max_steps must be positive") + if not beam_size > 0: + raise ValueError("beam_size must be positive") + if per_node_beam_size is not None and not per_node_beam_size > 0: + raise ValueError("per_node_beam_size must be positive") + if min_steps is not None: + if not min_steps >= 0: + raise ValueError("min_steps must be non-negative") + if not min_steps <= max_steps: + raise ValueError("min_steps must be less than or equal to max_steps") + + self._end_index = end_index + self.max_steps = max_steps + self.beam_size = beam_size + self.per_node_beam_size = per_node_beam_size or beam_size + self.sampler = sampler or DeterministicSampler() + self.min_steps = min_steps or 0 + self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer() + self.constraints = constraints or [] + + @staticmethod + def _reconstruct_sequences(predictions, backpointers): + # Reconstruct the sequences. + # shape: [(batch_size, beam_size, 1)] + reconstructed_predictions = [predictions[-1].unsqueeze(2)] + + if not backpointers: + return reconstructed_predictions + + # shape: (batch_size, beam_size) + cur_backpointers = backpointers[-1] + + for timestep in range(len(predictions) - 2, 0, -1): + # shape: (batch_size, beam_size, 1) + cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2) + + reconstructed_predictions.append(cur_preds) + + # shape: (batch_size, beam_size) + cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers) + + # shape: (batch_size, beam_size, 1) + final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2) + + reconstructed_predictions.append(final_preds) + + return reconstructed_predictions + + def search( + self, + start_predictions: torch.Tensor, + start_state: StateType, + step: StepFunctionType, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given a starting state and a step function, apply beam search to find the + most likely target sequences. + + Returns a tuple of `(predictions, final_scores)`, where `predictions` + has shape `(batch_size, beam_size, max_steps)` and `final_scores` + has shape `(batch_size, beam_size)`. + + .. note:: + If your step function returns `-inf` for some log probabilities + (like if you're using a masked log-softmax) then some of the "best" + sequences returned may also have `-inf` log probability. Specifically + this happens when the beam size is smaller than the number of actions + with finite log probability (non-zero probability) returned by the step function. + Therefore if you're using a mask you may want to check the results from `search` + and potentially discard sequences with non-finite log probability. + + :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`. + Usually the initial predictions are just the index of the "start" token + in the target vocabulary. + + :param start_state: The initial state passed to the `step` function. Each value of the state dict + should be a tensor of shape `(batch_size, *)`, where `*` means any other + number of dimensions. + + :param step: A function that is responsible for computing the next most likely tokens, + given the current state and the predictions from the last time step. + The function should accept two or three arguments: + + - a tensor of shape `(group_size,)` or representing the index of the predicted + tokens from the last time step, + - the current state, a `StateType`, and + - optionally, the timestep, an `int`. + + The `group_size` will be `batch_size * beam_size`, except in the initial + step, for which it will just be `batch_size`. + + The function is expected to return a tuple, where the first element + is a tensor of shape `(group_size, vocab_size)` containing + the log probabilities of the tokens for the next step, and the second + element is the updated state. The tensor in the state should have shape + `(group_size, *)`, where `*` means any other number of dimensions. + + """ + step_signature = signature(step) + if len(step_signature.parameters) < 3: + # If the step function we're given does not take the time step argument, wrap it + # in one that does. + old_step = cast(StepFunctionTypeNoTimestep, step) + + def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int): + del time_step + return old_step(last_predictions, state) + + return self._search(start_predictions, start_state, new_step) + else: + return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step)) + + def _search( + self, + start_predictions: torch.Tensor, + start_state: StateType, + step: StepFunctionTypeWithTimestep, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = start_predictions.size()[0] + + # List of (batch_size, beam_size) tensors. One for each time step. Does not + # include the start symbols, which are implicit. + predictions: List[torch.Tensor] = [] + + # List of (batch_size, beam_size) tensors. One for each time step. None for + # the first. Stores the index n for the parent prediction, i.e. + # predictions[t-1][i][n], that it came from. + backpointers: List[torch.Tensor] = [] + + constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints] + + # Calculate the first timestep. This is done outside the main loop + # because we are going from a single decoder input (the output from the + # encoder) to the top `beam_size` decoder outputs. On the other hand, + # within the main loop we are going from the `beam_size` elements of the + # beam to `beam_size`^2 candidates from which we will select the top + # `beam_size` elements for the next iteration. + # shape: (batch_size, num_classes) + start_class_log_probabilities, state = step(start_predictions, start_state, 0) + + num_classes = start_class_log_probabilities.size()[1] + + # Make sure `per_node_beam_size` is not larger than `num_classes`. + if self.per_node_beam_size > num_classes: + raise ValueError( + f"Vocab size ({num_classes:d}) too small " + f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n" + f"Please decrease beam_size or per_node_beam_size." + ) + + sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes) + + # Apply all constraints. + if self.constraints: + # shape: (batch_size, 1, num_classes) + expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1) + for constraint, constraint_state in zip(self.constraints, constraint_states): + expanded_start_class_log_probabilities = constraint.apply( + constraint_state, expanded_start_class_log_probabilities + ) + start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1) + + # Prevent selecting the end symbol if there is any min_steps constraint + if self.min_steps >= 1: + start_class_log_probabilities[:, self._end_index] = torch.finfo( + start_class_log_probabilities.dtype + ).min + + # Get the initial predicted classed and their log probabilities. + # shape: (batch_size, beam_size), (batch_size, beam_size) + ( + start_top_log_probabilities, + start_predicted_classes, + sampler_state, + ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state) + + if self.beam_size == 1 and (start_predicted_classes == self._end_index).all(): + warnings.warn( + "Empty sequences predicted. You may want to increase the beam size or ensure " + "your step function is working properly.", + RuntimeWarning, + ) + return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities + + # The log probabilities for the last time step. + # shape: (batch_size, beam_size) + last_log_probabilities = start_top_log_probabilities + + # shape: [(batch_size, beam_size)] + predictions.append(start_predicted_classes) + + # Log probability tensor that mandates that the end token is selected. + # shape: (batch_size * beam_size, num_classes) + log_probs_after_end = start_class_log_probabilities.new_full( + (batch_size * self.beam_size, num_classes), + torch.finfo(start_class_log_probabilities.dtype).min, + ) + log_probs_after_end[:, self._end_index] = 0.0 + + # Set the same state for each element in the beam. + self._update_initial_state(state, batch_size) + + for i, constraint in enumerate(self.constraints): + constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes) + + for timestep in range(self.max_steps - 1): + # shape: (batch_size * beam_size,) + last_predictions = predictions[-1].reshape(batch_size * self.beam_size) + + # If every predicted token from the last step is `self._end_index`, + # then we can stop early. + if (last_predictions == self._end_index).all(): + break + # Take a step. This get the predicted log probs of the next classes + # and updates the state. + # shape: (batch_size * beam_size, num_classes) + class_log_probabilities, state = step(last_predictions, state, timestep + 1) + + # Apply all constraints. + if self.constraints: + # shape: (batch_size, beam_size, num_classes) + reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1) + for constraint, constraint_state in zip(self.constraints, constraint_states): + reshaped_class_log_probabilities = constraint.apply( + constraint_state, reshaped_class_log_probabilities + ) + # shape: (batch_size * beam_size, num_classes) + class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1) + + # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token + # of the sequence (because `timestep` is 0-indexed and we generated the first token + # before the for loop). Here we block the end index if the search is not allowed to + # terminate on this iteration. + if timestep + 2 <= self.min_steps: + class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min + + # shape: (batch_size * beam_size, num_classes) + last_predictions_expanded = last_predictions.unsqueeze(-1).expand( + batch_size * self.beam_size, num_classes + ) + + # Here we are finding any beams where we predicted the end token in + # the previous timestep and replacing the distribution with a + # one-hot distribution, forcing the beam to predict the end token + # this timestep as well. + # shape: (batch_size * beam_size, num_classes) + cleaned_log_probabilities = torch.where( + last_predictions_expanded == self._end_index, + log_probs_after_end, + class_log_probabilities, + ) + + # shape (both): (batch_size * beam_size, per_node_beam_size) + top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes( + cleaned_log_probabilities, self.per_node_beam_size, sampler_state + ) + + # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size) + # so that we can add them to the current log probs for this timestep. + # This lets us maintain the log probability of each element on the beam. + # shape: (batch_size * beam_size, per_node_beam_size) + expanded_last_log_probabilities = ( + last_log_probabilities.unsqueeze(2) + .expand(batch_size, self.beam_size, self.per_node_beam_size) + .reshape(batch_size * self.beam_size, self.per_node_beam_size) + ) + + # shape: (batch_size * beam_size, per_node_beam_size) + summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities + + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_summed = summed_top_log_probabilities.reshape( + batch_size, self.beam_size * self.per_node_beam_size + ) + + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_predicted_classes = predicted_classes.reshape( + batch_size, self.beam_size * self.per_node_beam_size + ) + + # Keep only the top `beam_size` beam indices. + # shape (both): (batch_size, beam_size) + ( + restricted_beam_log_probs, + restricted_beam_indices, + sampler_state, + ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state) + + # Use the beam indices to extract the corresponding classes. + # shape: (batch_size, beam_size) + restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices) + + predictions.append(restricted_predicted_classes) + + # shape: (batch_size, beam_size) + last_log_probabilities = restricted_beam_log_probs + + # The beam indices come from a `beam_size * per_node_beam_size` dimension where the + # indices with a common ancestor are grouped together. Hence + # dividing by per_node_beam_size gives the ancestor. (Note that this is integer + # division as the tensor is a LongTensor.) + # shape: (batch_size, beam_size) + backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc") + backpointers.append(backpointer) + + # Keep only the pieces of the state tensors corresponding to the + # ancestors created this iteration. + self._update_state(state, backpointer) + + for i, constraint in enumerate(self.constraints): + constraint_states[i] = constraint.update_state( + constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer + ) + + # Warn about "-inf" log probabilities if not using any constraints (negligible + # log probabilities are expected when using constraints). + if not self.constraints and ( + not torch.isfinite(last_log_probabilities).all() + or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any() + ): + warnings.warn( + "Negligible log probabilities encountered ('-inf' or equivalent). " + "Some final sequences may not make sense. " + "This can happen when the beam size is larger than the number of valid (non-zero " + "probability) transitions that the step function produces.", + RuntimeWarning, + ) + + reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers) + + # shape: (batch_size, beam_size, max_steps) + all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2) + + # Calculate the final sequence scores + # shape: (batch_size, beam_size) + final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index) + + # Sort the sequences based on the final scores so the best scoring + # sequence is at index 0 + sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True) + sorted_all_predictions = torch.gather( + all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions) + ) + + return sorted_all_predictions, sorted_final_scores + + def _update_initial_state(self, state: StateType, batch_size: int): + """ + Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`. + """ + for key, state_tensor in state.items(): + if state_tensor is None: + continue + # shape: (batch_size * beam_size, *) + _, *last_dims = state_tensor.size() + state[key] = ( + state_tensor.unsqueeze(1) + .expand(batch_size, self.beam_size, *last_dims) + .reshape(batch_size * self.beam_size, *last_dims) + ) + + def _update_state(self, state: StateType, backpointer: torch.Tensor): + batch_size = backpointer.size()[0] + + for key, state_tensor in state.items(): + if state_tensor is None: + continue + _, *last_dims = state_tensor.size() + # shape: (batch_size, beam_size, *) + expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand( + batch_size, self.beam_size, *last_dims + ) + # shape: (batch_size * beam_size, *) + state[key] = ( + state_tensor.reshape(batch_size, self.beam_size, *last_dims) + .gather(1, expanded_backpointer) + .reshape(batch_size * self.beam_size, *last_dims) + ) \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/config.py b/3.test_cases/neuronx-distributed/olmo/olmo/config.py new file mode 100644 index 00000000..9acabee6 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/config.py @@ -0,0 +1,1144 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from glob import glob +from pathlib import Path +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from omegaconf import OmegaConf as om +from omegaconf.errors import OmegaConfBaseException +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy + +from .aliases import PathOrStr +from .exceptions import OLMoConfigurationError +from .util import StrEnum + +__all__ = [ + "ActivationType", + "ActivationCheckpointingStrategy", + "BlockType", + "LayerNormType", + "InitFnType", + "ModelConfig", + "OptimizerType", + "OptimizerConfig", + "SchedulerType", + "SchedulerConfig", + "DataConfig", + "InstanceFilterConfig", + "EvaluatorConfig", + "TokenizerConfig", + "TrainConfig", + "PaddingDirection", + "TruncationDirection", + "SpeedMonitorConfig", + "WandbConfig", + "CompilerConfig", + "WandbConfig", + "FSDPPrecision", + "FSDPWrapStrategy", + "FSDPConfig", + "CheckpointType", +] + +C = TypeVar("C", bound="BaseConfig") +D = TypeVar("D", bound="DictConfig|ListConfig") + + +class BaseConfig: + @classmethod + def _register_resolvers(cls, validate_paths: bool = True): + # Expands path globs into a list. + def path_glob(*paths) -> List[str]: + out = [] + for path in paths: + matches = sorted(glob(path)) + if not matches and validate_paths: + raise FileNotFoundError(f"{path} does not match any files or dirs") + out.extend(matches) + return out + + # Chooses the first path in the arguments that exists. + def path_choose(*paths) -> str: + from .util import is_url + + for path in paths: + if is_url(path) or Path(path).exists(): + return path + if validate_paths: + raise FileNotFoundError(", ".join(paths)) + else: + return "" + + # Finds the latest checkpoint in a folder. + def path_last_checkpoint(path) -> str: + from .util import find_latest_checkpoint + + latest_checkpoint = find_latest_checkpoint(path) + if latest_checkpoint is None: + if validate_paths: + raise FileNotFoundError(f"Could not find a latest checkpoint at {path}") + else: + return "" + else: + return str(latest_checkpoint) + + om.register_new_resolver("path.glob", path_glob, replace=True) + om.register_new_resolver("path.choose", path_choose, replace=True) + om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True) + + @classmethod + def update_legacy_settings(cls, config: D) -> D: + """ + Update the legacy config settings whose schemas have undergone backwards-incompatible changes. + """ + return config + + @classmethod + def new(cls: Type[C], **kwargs) -> C: + cls._register_resolvers() + conf = om.structured(cls) + try: + if kwargs: + conf = om.merge(conf, kwargs) + return cast(C, om.to_object(conf)) + except OmegaConfBaseException as e: + raise OLMoConfigurationError(str(e)) + + @classmethod + def load( + cls: Type[C], + path: PathOrStr, + overrides: Optional[List[str]] = None, + key: Optional[str] = None, + validate_paths: bool = True, + ) -> C: + """Load from a YAML file.""" + cls._register_resolvers(validate_paths=validate_paths) + schema = om.structured(cls) + try: + raw = om.load(str(path)) + if key is not None: + raw = raw[key] # type: ignore + raw = cls.update_legacy_settings(raw) + conf = om.merge(schema, raw) + if overrides: + conf = om.merge(conf, om.from_dotlist(overrides)) + return cast(C, om.to_object(conf)) + except OmegaConfBaseException as e: + raise OLMoConfigurationError(str(e)) + + def save(self, path: PathOrStr) -> None: + """Save to a YAML file.""" + om.save(config=self, f=str(path)) + + def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]: + out = asdict(self) # type: ignore + if exclude is not None: + for name in exclude: + if name in out: + del out[name] + return out + + +class LayerNormType(StrEnum): + default = "default" + """ + The default LayerNorm implementation, equivalent to PyTorch's built-in version. + """ + + low_precision = "low_precision" + """ + A low-precision version of the default LayerNorm. + """ + + rms = "rms" + """ + An RMSNorm implementation. When using ``torch.compile`` this is + probably the fastest implementation. + """ + + +class ActivationType(StrEnum): + gelu = "gelu" + relu = "relu" + swiglu = "swiglu" + + +class BlockType(StrEnum): + sequential = "sequential" + + llama = "llama" + """ + A block similar to the sequential block with slightly different + implementations of operations like attention to imitate the behavior of Llama. + """ + + +class InitFnType(StrEnum): + mitchell = "mitchell" + """ + The strategy suggested to us by Mitchell Wortsman from UW. + This uses a truncated normal distribution with an adaptive standard deviation that depends + on the size of the weights as well as the depth of the layer. + """ + + normal = "normal" + """ + All weights are initialized from the same normal distribution. + """ + + kaiming_normal = "kaiming_normal" + """ + All weights are initialized with the Kaiming method from a normal distribution. + Note this currently won't work with FSDP. + """ + + fan_in = "fan_in" + """ + "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in`` + is the input dimensionality of the kernel. + """ + + full_megatron = "full_megatron" + """ + This is what metaseq calls "full megatron init". It is the init used for Llama 2. + """ + + +@dataclass +class ModelConfig(BaseConfig): + """ + OLMo (model) configuration. + """ + + # Note that the defaults for these attributes are equivalent to the base GPT2 model. + + d_model: int = 768 + """ + The hidden size of the model. + """ + + n_heads: int = 12 + """ + The number of self-attention heads. + """ + + n_kv_heads: Optional[int] = None + """ + The number of heads to use for keys and values. Defaults to `n_heads`. + Set this to ``None`` or ``n_heads`` for normal multi-head attention. + Set this to 1 for multi-query attention. + Set it to some in-between value for Llama2-style grouped query attention. + """ + + clip_qkv: Optional[float] = None + """ + Clip QKV to this value when set. + """ + + n_layers: int = 12 + """ + The number of layers/blocks. + """ + + mlp_ratio: int = 4 + """ + The ratio of the inner MLP dimensionality to ``d_model``. + This is only used when ``mlp_hidden_size`` is not set. + """ + + mlp_hidden_size: Optional[int] = None + """ + Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`. + """ + + activation_type: ActivationType = ActivationType.swiglu + """ + The activation function to use within the MLP layers. + """ + + block_type: BlockType = BlockType.sequential + """ + The transformer block implementation. + """ + + block_group_size: int = 1 + """ + The number of blocks to group together into a single parent block. + This has no affect on the number of parameters in the model and is only used to wrap groups + of blocks together with a single FSDP wrapper during training. + """ + + alibi: bool = False + """ + If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``. + """ + + alibi_bias_max: float = 8.0 + """ + Maximum absolute value of ALiBi bias. + """ + + rope: bool = False + """ + Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. + """ + + rope_full_precision: bool = True + """ + If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise, + apply RoPE at the precision of the input. + """ + + flash_attention: bool = False + """ + If ``True``, use ``FlashAttention``. + """ + + attention_dropout: float = 0.1 + """ + The dropout probability within the attention modules. + """ + + multi_query_attention: Optional[bool] = None + """ + Deprecated. Use n_kv_heads instead. + """ + + attention_layer_norm: bool = False + """ + Apply layer norm to the keys and queries within the attention mechanism. + This can help stabilize training. + """ + + residual_dropout: float = 0.1 + """ + The dropout probability for the MLP and attention output within each block. + """ + + embedding_dropout: float = 0.1 + """ + The dropout probability for embeddings. + """ + + layer_norm_type: LayerNormType = LayerNormType.default + """ + The layernorm implementation to use. + """ + + layer_norm_with_affine: bool = True + """ + Whether to include bias and weight parameters for the layer norms. + This only affects layer norms that are immediately followed by a linear layer in the forward pass, + so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine` + to ``False``. + """ + + attention_layer_norm_with_affine: bool = True + """ + Toggle affine transform for the QK norms. + """ + + max_sequence_length: int = 1024 + """ + The maximum input sequence length supported by the model. + """ + + include_bias: bool = True + """ + Whether or not to include bias parameters in linear layers. + In PaLM, they got rid of all bias terms because they found that large + models tend to have near 0 bias terms anyway. + """ + + bias_for_layer_norm: Optional[bool] = None + """ + Whether or not to include bias parameters in layer norm. + This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in + layer norm. + When this is None (the default), it inherits the setting from include_bias. + """ + + scale_logits: bool = False + """ + If ``True``, scale the output logits by ``1 / sqrt(d_model)``. + """ + + vocab_size: int = 50257 + """ + Vocabulary size of the model. + """ + + embedding_size: Optional[int] = 50304 + """ + The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default + to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the + next multiple of 128 that's greater than ``vocab_size`` can improve throughput + substantially. + """ + + weight_tying: bool = True + """ + Whether to tie output linear weights to the input embedding. + """ + + eos_token_id: int = 50256 + """ + The ID of the end-of-sentence special token. + """ + + pad_token_id: int = 50256 + """ + The ID of the token to use for padding. Defaults to the ID of the EOS token. + """ + + init_device: Optional[str] = None + """ + The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta". + """ + + init_fn: InitFnType = InitFnType.normal + """ + The weight initialization strategy. + """ + + init_std: float = 0.02 + """ + The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such + as "normal". + """ + + init_cutoff_factor: Optional[float] = None + """ + A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such + as "normal". Setting this to None means values are not cutoff. + """ + + precision: Optional[str] = None + """ + Precision used to train/evaluate with. You shouldn't set this directly. + See :data:`TrainConfig.precision` instead. + """ + + @property + def effective_n_kv_heads(self) -> int: + if self.n_kv_heads is None: + if self.multi_query_attention is True: + return 1 + else: + return self.n_heads + else: + if self.multi_query_attention is None: + return self.n_kv_heads + if self.multi_query_attention: + n_kv_heads_should_be = 1 + else: + n_kv_heads_should_be = self.n_heads + if self.n_kv_heads == n_kv_heads_should_be: + return n_kv_heads_should_be + else: + raise OLMoConfigurationError( + "You can't set `multi_query_attention` and `n_kv_heads` at the same time." + ) + + +class OptimizerType(StrEnum): + lionw = "lionw" + adamw = "adamw" + + +@dataclass +class OptimizerConfig(BaseConfig): + name: OptimizerType = OptimizerType.lionw + learning_rate: float = 1.0e-4 + weight_decay: float = 0.01 + betas: Tuple[float, float] = (0.9, 0.95) + eps: float = 1e-5 + + no_decay_norm_and_bias: Optional[bool] = None + """ + Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead. + """ + + decay_norm_and_bias: bool = False + decay_embeddings: bool = False + metrics_log_interval: Optional[int] = None + """ + The interval with which to collect and log detailed parameter-specific metrics. + This only applies when logging to W&B, since these metrics won't be logged to the console. + If not set, defaults to the wandb `log_interval`. + """ + + def __post_init__(self): + self.betas = tuple(self.betas) # type: ignore[assignment] + + @classmethod + def update_legacy_settings(cls, config: D) -> D: + new_config = config.copy() + if om.is_dict(new_config): + assert isinstance(new_config, DictConfig) + + if hasattr(new_config, "name") and new_config.name == "decoupled_lionw": + new_config.name = "lionw" + if hasattr(new_config, "eps"): + del new_config.eps + + return new_config + + +class SchedulerType(StrEnum): + cosine_with_warmup = "cosine_with_warmup" + linear_with_warmup = "linear_with_warmup" + inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup" + max_scheduler = "max_scheduler" + constant = "constant" + + +class SchedulerUnits(StrEnum): + steps = "steps" + tokens = "tokens" + + +@dataclass +class SchedulerConfig(BaseConfig): + name: SchedulerType = SchedulerType.cosine_with_warmup + units: SchedulerUnits = SchedulerUnits.steps + t_warmup: Union[int, float] = 100 + t_max: Optional[Union[int, float]] = None + alpha_f: float = 0.1 + + grad_clip_warmup_steps: Optional[Union[int, float]] = None + """ + The warmup period for which the max grad norm (or norm ratio) will be set to its + warmup value of `max_grad_norm * grad_clip_warmup_factor`. + """ + + grad_clip_warmup_factor: Optional[float] = None + """ + The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period + vs after the warmup period. + """ + + warmup_min_lr: Optional[float] = None + """ + The starting LR during the warmup period. If not set this defaults to 10% of + the target LR. + """ + + +class PaddingDirection(StrEnum): + right = "right" + left = "left" + + +@dataclass +class InstanceFilterConfig(BaseConfig): + repetition_max_period: int = 13 + repetition_min_period: int = 1 + repetition_max_count: int = 32 + + +@dataclass +class DataConfig(BaseConfig): + paths: Optional[List[str]] = None + memmap_dtype: str = "uint16" + datasets: Optional[Dict[str, List[str]]] = None + label_mask_paths: Optional[List[str]] = None + pad_direction: PaddingDirection = PaddingDirection.right + generate_attention_mask: bool = False + num_workers: int = 0 + drop_last: bool = False + pin_memory: bool = False + prefetch_factor: Optional[int] = None + persistent_workers: bool = False + timeout: int = 0 + seed: Optional[int] = None + instance_filter: Optional[InstanceFilterConfig] = None + + @property + def effective_memmap_dtype(self): + if self.memmap_dtype == "uint8": + return np.uint8 + if self.memmap_dtype == "uint16": + return np.uint16 + elif self.memmap_dtype == "uint32": + return np.uint32 + elif self.memmap_dtype == "uint64": + return np.uint64 + # default to uint16 if not set + return np.uint16 + + +class EvaluatorType(StrEnum): + downstream = "downstream" + lm = "lm" + + +@dataclass +class EvaluatorConfig(BaseConfig): + label: str + type: EvaluatorType = EvaluatorType.lm + data: DataConfig = field(default_factory=DataConfig) + device_eval_batch_size: Optional[int] = None + subset_num_batches: Optional[int] = None + + +class TruncationDirection(StrEnum): + right = "right" + left = "left" + + +@dataclass +class TokenizerConfig(BaseConfig): + identifier: str = "gpt2" + truncate_direction: TruncationDirection = TruncationDirection.right + + +@dataclass +class WandbConfig(BaseConfig): + project: Optional[str] = None + entity: Optional[str] = "ai2-llm" + group: Optional[str] = None + name: Optional[str] = None + tags: Optional[List[str]] = field(default_factory=lambda: ["watching"]) + log_artifacts: bool = False + rank_zero_only: bool = True + log_interval: int = 1 + + +@dataclass +class SpeedMonitorConfig(BaseConfig): + window_size: int = 100 + gpu_flops_available: Optional[Union[float, int]] = None + + +@dataclass +class CompilerConfig(BaseConfig): + mode: Optional[str] = None + """ + The mode to compile the model in. At the moment this can be "default", + "reduce-overhead" (useful for smaller models/batches), or "max-autotune" + (the fastest for larger models, but takes a long time to compile). + """ + + fullgraph: bool = False + """ + Whether it is OK to break model into several subgraphs when compiling. + Note that this is not compatible with FSDP. + """ + + backend: str = "inductor" + """ + The backend to use. + """ + + +class FSDPWrapStrategy(StrEnum): + by_block = "by_block" + """ + Wrap each OLMo block with its own FSDP instance. + """ + + by_block_and_size = "by_block_and_size" + """ + Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well. + """ + + by_block_group = "by_block_group" + """ + Wrap each block group together into its own FSDP instance. + This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1. + """ + + by_block_group_and_size = "by_block_group_and_size" + """ + Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well. + """ + + size_based = "size_based" + """ + Used PyTorch's default size-based auto wrap policy. + """ + + one_in_two = "one_in_two" + one_in_three = "one_in_three" + one_in_four = "one_in_four" + one_in_five = "one_in_five" + + +class FSDPPrecision(StrEnum): + pure = "pure" + """ + Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``, + and ``buffer_dtype`` all set to the autocast precision data type. + """ + + mixed = "mixed" + """ + Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype`` + set to the autocast precision data type, while ``reduce_dtype`` is set to fp32. + """ + + +@dataclass +class FSDPConfig(BaseConfig): + use_orig_params: bool = True + """ + This must be ``True`` if using ``compile`` or you want to track the parameter norm during training. + """ + + sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD + + wrapping_strategy: Optional[FSDPWrapStrategy] = None + """ + The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level + FSDP instance. + """ + + precision: FSDPPrecision = FSDPPrecision.pure + + hybrid_sharding_num_model_replicas: Optional[int] = None + """ + The number of model instances, when using a hybrid sharding strategy. + If not ``None``, this must divide the total number of nodes. If ``None``, the default, + a model instance is used per node (as determined by ``get_world_size() // get_local_world_size()``). + PyTorch's default HSDP behavior matches this default behavior. + """ + + +class CheckpointType(StrEnum): + sharded = "sharded" + unsharded = "unsharded" + sharded_ephemeral = "sharded_ephemeral" + + +class ShardedCheckpointerType(StrEnum): + torch_new = "torch_new" + torch_legacy = "torch_legacy" + local = "local" + olmo_core = "olmo_core" + + +class ActivationCheckpointingStrategy(StrEnum): + whole_layer = "whole_layer" + """ + Checkpoint every transformer layer. + """ + + one_in_two = "one_in_two" + """ + Checkpoint one in two transformer layers. + """ + + one_in_three = "one_in_three" + """ + Checkpoint one in three transformer layers. + """ + + one_in_four = "one_in_four" + """ + Checkpoint one in four transformer layers. + """ + + two_in_three = "two_in_three" + """ + Checkpoint two out of every three transformer layers. + """ + + three_in_four = "three_in_four" + """ + Checkpoint three out of four of every transformer layers. + """ + + fine_grained = "fine_grained" + """ + Focus checkpointing on where it is cheap to recompute and saves most memory. + """ + + +@dataclass +class TrainConfig(BaseConfig): + """ + OLMo training configuration. + """ + + run_name: Optional[str] = None + """ + The name of the run. + """ + + seed: int = 6198 + """ + Used to seed all initial RNG states. + """ + + epoch: Optional[int] = None + """ + Increment this when starting a new epoch. + """ + + dry_run: bool = False + """ + If ``True``, don't actually train. + """ + + model: ModelConfig = field(default_factory=ModelConfig) + """ + OLMo Model configuration. + """ + + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + """ + Optimizer configuration. + """ + + scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) + """ + Learning rate scheduler configuration. + """ + + data: DataConfig = field(default_factory=DataConfig) + """ + Training data configuration. + """ + + restore_dataloader: bool = True + """ + When restarting, restore the data loader to where it left off. + If you restarting in order to train on a different dataset, set this to ``False``. + """ + + fast_forward_batches: Optional[int] = None + """ + When restarting, use this to fast-forward the dataloader beyond the last checkpoint. + This can be useful when restarting due to a loss spike in order to skip the data that + corresponded to the spike. + """ + + evaluators: List[EvaluatorConfig] = field(default_factory=list) + """ + Evaluation configurations. + """ + + eval_interval: int = 1000 + """ + How often (in terms of batches) to run evaluations. + """ + + tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig) + """ + Tokenizer configuration. + """ + + save_folder: str = "./" + """ + The directory to save checkpoints to. + """ + + remote_save_folder: Optional[str] = None + """ + A folder in a cloud bucket to upload saved checkpoints to. + """ + + canceled_check_interval: int = 50 + """ + How often (in batches) to check if the run has been canceled or reached its time limit. + """ + + save_interval: int = 1000 + """ + How often (in terms of steps) to save sharded training state checkpoints. + """ + + save_interval_unsharded: Optional[int] = None + """ + How often (if at all) to save unsharded training state checkpoint. + For large models it can be costly to save these, so it usually makes sense to save + these less often than regular (sharded) training checkpoints. + """ + + save_interval_ephemeral: Optional[int] = None + """ + How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same + as those saved every `save_interval` except that at most only the most recent one of these is kept. + This is useful when you want to checkpoint often for restarts in case of failures, but don't + want to keep the majority of these checkpoints. + + For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save + a temporary checkpoint every 100 steps in case your job fails. In that case you would + set `save_interval=1000` and `save_interval_ephemeral=100`. + """ + + save_num_checkpoints_to_keep: int = -1 + """ + How many sharded checkpoints to keep. + """ + + save_num_unsharded_checkpoints_to_keep: int = -1 + """ + How many unsharded checkpoints to keep. + """ + + save_overwrite: bool = False + """ + If ``True``, overwrite any conflicting checkpoint files. + """ + + force_save_unsharded: bool = False + """ + Save an unsharded checkpoint before training (even during a dry run). + Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded + checkpoint into an unsharded checkpoint. + """ + + no_pre_train_checkpoint: bool = False + """ + Skip saving pre-train checkpoint. + """ + + load_path: Optional[str] = None + """ + The path to a training checkpoint to restore/resume from. + + Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes + a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory. + For example, + + ```bash + --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}' + ``` + """ + + load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None + """ + The sharded checkpointer type to use to load the initial checkpoint from ``load_path``. + """ + + reset_optimizer_state: bool = False + """ + When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized. + We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning + curve (according to the current learning rate schedule settings), and continues from there. + """ + + reset_trainer_state: bool = False + """ + When this is set we don't restore the trainer state from a checkpoint. + """ + + sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy + """ + The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training. + """ + + new_style_checkpoints: Optional[bool] = None + """ + Deprecated. Use ``sharded_checkpointer`` instead. + """ + + max_duration: Union[int, str] = 10000 + """ + How long to train for. + + If specified without a unit (the default), the units are assumed to be steps. + You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until + 2 trillion tokens. + """ + + global_train_batch_size: int = 512 + """ + The effective global batch size. + """ + + device_train_batch_size: Optional[int] = None # calculated automatically + """ + Don't set this manually. This will be set to ``global_train_batch_size // world_size``. + """ + + device_train_microbatch_size: int = 16 + """ + The number of instances passed to the model in a single forward-backward pass. You should set + this as large as you can based on available GPU memory. + """ + + device_eval_batch_size: int = 16 + """ + The number of evaluation instances passed to the model in a single forward pass on each device. + """ + + eval_subset_num_batches: int = -1 + """ + The number of batches to use for downstream evaluation from each dataset. + """ + + eval_on_load: bool = False + """ + When resuming from a checkpoint, run the evaluation loop right away. + """ + + device_train_grad_accum: Optional[int] = None # calculated automatically + """ + Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``. + """ + + max_grad_norm: Optional[float] = None + """ + Clip gradient norms to this value if set. + """ + + max_grad_norm_ratio: Optional[float] = None + """ + If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`. + This takes priority over `max_grad_norm` when set. + """ + + precision: Optional[str] = None + """ + Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32"). + """ + + wandb: Optional[WandbConfig] = None + """ + Weights & Biases configuration. + """ + + speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig) + """ + Speed monitor configuration. + """ + + console_log_interval: int = 1 + """ + How often to log to the console. + """ + + gen1_gc_interval: Optional[int] = 1 + """ + How often (in steps) to run generation 1 garbage collection. + Set to ``None`` to use automatic garbage collection (i.e. we don't mess with it). + """ + + compile: Optional[CompilerConfig] = None + """ + Settings for compiling the model with ``torch.compile()``. + """ + + fsdp: FSDPConfig = field(default_factory=FSDPConfig) + """ + Fully sharded data parallel settings. + """ + + softmax_auxiliary_loss: bool = False + """ + If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax + normalizing term to be close to 0. + """ + + time_limit: Optional[float] = None + """ + The maximum amount of time to train for before saving a checkpoint and ending early. + """ + + extra_steps_after_cancel: int = 10 + """ + Under certain conditions when a run is canceled we train for a few extra steps after saving + the final checkpoint so that when the run is restarted from the latest checkpoint we have some + overlap in metrics. + """ + + early_stopping_factor: Optional[float] = None + + save_data_indices: bool = True + """ + Save training data indices from each batch for each worker. + """ + + python_profiling: bool = False + """ + Whether to run the Python profiler on batches 6, 7, and 8. + """ + + torch_profiling: bool = False + """ + Whether to run the PyTorch profiler on batches 6, 7, and 8. + """ + + stop_at: Optional[int] = None + """ + Stop at a specific step. + """ + + stop_after: Optional[int] = None + """ + Stop after a specific number of steps. + """ + + activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None + """ + The activation checkpointing strategy to use. + """ + + fused_loss: Optional[bool] = None + """ + Whether to use the fused CE loss function from `flash-attn`. + """ + + @property + def autocast_precision(self) -> torch.dtype: + if self.precision == "amp_bf16": + return torch.bfloat16 + elif self.precision == "amp_fp16": + return torch.float16 + elif self.precision == "fp32": + return torch.float32 + else: + raise ValueError(f"Unexpected precision type '{self.precision}'") + + @property + def fsdp_precision(self) -> MixedPrecision: + if self.fsdp.precision == FSDPPrecision.pure: + return MixedPrecision( + param_dtype=self.autocast_precision, + reduce_dtype=self.autocast_precision, + buffer_dtype=self.autocast_precision, + ) + elif self.fsdp.precision == FSDPPrecision.mixed: + return MixedPrecision( + param_dtype=self.autocast_precision, + reduce_dtype=torch.float32, + buffer_dtype=self.autocast_precision, + ) + else: + raise NotImplementedError(f"{self.fsdp.precision}") + + @classmethod + def update_legacy_settings(cls, config: D) -> D: + new_config = config.copy() + if om.is_dict(new_config): + assert isinstance(new_config, DictConfig) + + if hasattr(new_config, "activation_checkpointing"): + if new_config.activation_checkpointing is False: + new_config.activation_checkpointing = None + if new_config.activation_checkpointing is True: + new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer + + if hasattr(new_config, "optimizer"): + new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer) + + return new_config \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/datasets.py b/3.test_cases/neuronx-distributed/olmo/olmo/datasets.py new file mode 100644 index 00000000..95d1f28c --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/datasets.py @@ -0,0 +1,135 @@ +import pickle + +import torch +from torch.utils.data import Dataset + +class AdditionDataset(Dataset): + """ + Creates n-digit addition problems. For example, if n=2, then an example + addition problem would be to add 85 + 50 = 135. This problem would be + represented as the following string for the GPT: + + "8550531" + + This is because: + - we are discarding the + and =, which are not necessary. We just encode the digits + of the input numbers concatenated together. + - the result 135 is encoded backwards to make the addition easier to learn for the + GPT model, because of how the addition algorithm works. + + As one more example, the problem 6 + 39 = 45 would be encoded as: + + "0639054" + + where you will notice that we are padding with zeros to make sure that we always + produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7. + At test time, we will feed in an addition problem by giving the first 2n digits, + and hoping that the GPT model completes the sequence with the next (n+1) digits + correctly. + """ + + def __init__(self, split, ndigit=2): + self.split = split # train/test + + # split up all addition problems into either training data or test data + self.ndigit = ndigit + assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support" + num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers + rng = torch.Generator() + rng.manual_seed(1337) + perm = torch.randperm(num, generator=rng) + num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500 + self.ixes = perm[:num_test] if split == 'test' else perm[num_test:] + + def get_vocab_size(self): + return 10 # digits 0..9 + + def get_block_size(self): + # a,b,a+b, and +1 due to potential carry overflow, + # but then also -1 because very last digit doesn't ever plug back + # as there is no explicit token to predict, it is implied + return 3*self.ndigit + 1 - 1 + + def __len__(self): + return self.ixes.nelement() + + def __getitem__(self, idx): + ndigit = self.ndigit + # given a problem index idx, first recover the associated a + b + idx = self.ixes[idx].item() + nd = 10**ndigit + a = idx // nd + b = idx % nd + # calculate the "label" of the addition problem a + b + c = a + b + # encode the digits of a, b, c into strings + astr = f'%0{ndigit}d' % a + bstr = f'%0{ndigit}d' % b + cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier + render = astr + bstr + cstr + dix = [int(s) for s in render] # convert each character to its token index + # x will be input to GPT and y will be the associated expected outputs + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence + y[:ndigit*2-1] = -1 # we will only train in the output locations. -1 will mask loss to zero + return x, y + +class SortDataset(Dataset): + """ + Dataset for the Sort problem. E.g. for problem length 6: + Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2 + Which will feed into the transformer concatenated as: + input: 0 0 2 1 0 1 0 0 0 1 1 + output: I I I I I 0 0 0 1 1 2 + where I is "ignore", as the transformer is reading the input sequence + """ + + def __init__(self, split, length=6, num_digits=3): + assert split in {'train', 'test'} + self.split = split + self.length = length + self.num_digits = num_digits + + def __len__(self): + return 10000 # ... + + def get_vocab_size(self): + return self.num_digits + + def get_block_size(self): + # the length of the sequence that will feed into transformer, + # containing concatenated input and the output, but -1 because + # the transformer starts making predictions at the last input element + return self.length * 2 - 1 + + def __getitem__(self, idx): + + # use rejection sampling to generate an input example from the desired split + while True: + # generate some random integers + inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long) + # half of the time let's try to boost the number of examples that + # have a large number of repeats, as this is what the model seems to struggle + # with later in training, and they are kind of rate + if torch.rand(1).item() < 0.5: + if inp.unique().nelement() > self.length // 2: + # too many unqiue digits, re-sample + continue + # figure out if this generated example is train or test based on its hash + h = hash(pickle.dumps(inp.tolist())) + inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test + if inp_split == self.split: + break # ok + + # solve the task: i.e. sort + sol = torch.sort(inp)[0] + + # concatenate the problem specification and the solution + cat = torch.cat((inp, sol), dim=0) + + # the inputs to the transformer will be the offset sequence + x = cat[:-1].clone() + y = cat[1:].clone() + # we only want to predict at output locations, mask out the loss at the input locations + y[:self.length-1] = -1 + return x, y diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/exceptions.py b/3.test_cases/neuronx-distributed/olmo/olmo/exceptions.py new file mode 100644 index 00000000..9670b4b2 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/exceptions.py @@ -0,0 +1,50 @@ +__all__ = [ + "OLMoError", + "OLMoConfigurationError", + "OLMoCliError", + "OLMoEnvironmentError", + "OLMoNetworkError", + "OLMoCheckpointError", +] + + +class OLMoError(Exception): + """ + Base class for all custom OLMo exceptions. + """ + + +class OLMoConfigurationError(OLMoError): + """ + An error with a configuration file. + """ + + +class OLMoCliError(OLMoError): + """ + An error from incorrect CLI usage. + """ + + +class OLMoEnvironmentError(OLMoError): + """ + An error from incorrect environment variables. + """ + + +class OLMoNetworkError(OLMoError): + """ + An error with a network request. + """ + + +class OLMoCheckpointError(OLMoError): + """ + An error occurred reading or writing from a checkpoint. + """ + + +class OLMoThreadError(Exception): + """ + Raised when a thread fails. + """ \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/initialization.py b/3.test_cases/neuronx-distributed/olmo/olmo/initialization.py new file mode 100644 index 00000000..9e4d9543 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/initialization.py @@ -0,0 +1,22 @@ +from typing import Optional, Union + +import torch.nn as nn + +__all__ = ["init_normal"] + + +def init_normal( + module: Union[nn.Linear, nn.Embedding], + std: float, + init_cutoff_factor: Optional[float] = None, +): + # weights + if init_cutoff_factor is not None: + cutoff_value = init_cutoff_factor * std + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + else: + nn.init.normal_(module.weight, mean=0.0, std=std) + + # biases + if isinstance(module, nn.Linear) and module.bias is not None: + nn.init.zeros_(module.bias) \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/model.py b/3.test_cases/neuronx-distributed/olmo/olmo/model.py new file mode 100644 index 00000000..f9e3c011 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/model.py @@ -0,0 +1,1689 @@ +""" +Adapted from +[MosaiclML](https://github.com/mosaicml/examples.git) and +[minGPT](https://github.com/karpathy/minGPT.git) +""" + +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import ( + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + cast, +) + +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from .aliases import PathOrStr +from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler +from .config import ( + ActivationCheckpointingStrategy, + ActivationType, + BlockType, + CheckpointType, + FSDPWrapStrategy, + InitFnType, + LayerNormType, + ModelConfig, +) +from .exceptions import OLMoConfigurationError +from .initialization import init_normal +from .torch_util import ensure_finite_ + +if sys.version_info.minor > 8: + from collections.abc import MutableMapping +elif sys.version_info.minor == 8: + from typing import MutableMapping +else: + raise SystemExit("This script supports Python 3.8 or higher") + +__all__ = [ + "LayerNormBase", + "LayerNorm", + "RMSLayerNorm", + "RotaryEmbedding", + "Activation", + "GELU", + "ReLU", + "SwiGLU", + "OLMoBlock", + "OLMoSequentialBlock", + "OLMo", + "OLMoOutput", + "OLMoGenerateOutput", +] + + +log = logging.getLogger(__name__) + + +def activation_checkpoint_function(cfg: ModelConfig): + preserve_rng_state = ( + (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0) + ) + from torch.utils.checkpoint import checkpoint + + return partial( + checkpoint, + preserve_rng_state=preserve_rng_state, + use_reentrant=False, + ) + + +def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool: + if strategy is None: + return False + elif ( + (strategy == ActivationCheckpointingStrategy.whole_layer) + or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0) + or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0) + or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0) + or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0) + or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0) + ): + return True + else: + return False + + +class BufferCache(dict, MutableMapping[str, torch.Tensor]): + """ + Cache for attention biases and other things that would normally be stored as buffers. + We avoid using buffers because we've run into various issues doing so with FSDP. + In general it appears the way FSDP handles buffers is not well-defined. + It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid + since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into + NaNs when they're synchronized due to casting or some other issue. + """ + + +def _non_meta_init_device(config: ModelConfig) -> torch.device: + if config.init_device is not None and config.init_device != "meta": + return torch.device(config.init_device) + else: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Dropout(nn.Dropout): + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.p == 0.0: + return input + else: + return F.dropout(input, self.p, self.training, self.inplace) + + +class LayerNormBase(nn.Module): + def __init__( + self, + config: ModelConfig, + *, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + eps: float = 1e-05, + ): + super().__init__() + self.config = config + self.eps = eps + self.normalized_shape = (size or config.d_model,) + if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine): + self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device)) + use_bias = self.config.bias_for_layer_norm + if use_bias is None: + use_bias = self.config.include_bias + if use_bias: + self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device)) + else: + self.register_parameter("bias", None) + else: + self.register_parameter("bias", None) + self.register_parameter("weight", None) + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase: + if config.layer_norm_type == LayerNormType.default: + return LayerNorm(config, size=size, low_precision=False, **kwargs) + elif config.layer_norm_type == LayerNormType.low_precision: + return LayerNorm(config, size=size, low_precision=True, **kwargs) + elif config.layer_norm_type == LayerNormType.rms: + return RMSLayerNorm(config, size=size, **kwargs) + else: + raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'") + + def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. + if tensor.device.type == "cuda" and torch.is_autocast_enabled(): + return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype()) + elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype()) + else: + return tensor + + def reset_parameters(self): + if self.weight is not None: + torch.nn.init.ones_(self.weight) # type: ignore + if self.bias is not None: + torch.nn.init.zeros_(self.bias) # type: ignore + + +class LayerNorm(LayerNormBase): + """ + The default :class:`LayerNorm` implementation which can optionally run in low precision. + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + low_precision: bool = False, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-05, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) + self.low_precision = low_precision + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.low_precision: + module_device = x.device + downcast_x = self._cast_if_autocast_enabled(x) + downcast_weight = ( + self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + ) + downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return F.layer_norm( + downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps + ) + else: + return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) + + +class RMSLayerNorm(LayerNormBase): + """ + RMS layer norm, a simplified :class:`LayerNorm` implementation + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-5, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.autocast(enabled=False, device_type=x.device.type): + og_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(og_dtype) + + if self.weight is not None: + if self.bias is not None: + return self.weight * x + self.bias + else: + return self.weight * x + else: + return x + + +class RotaryEmbedding(nn.Module): + """ + [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864). + """ + + def __init__(self, config: ModelConfig, cache: BufferCache): + super().__init__() + self.config = config + self.__cache = cache + # Warm up cache. + self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config)) + + def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + if ( + (pos_sin := self.__cache.get("rope_pos_sin")) is not None + and (pos_cos := self.__cache.get("rope_pos_cos")) is not None + and pos_sin.shape[-2] >= seq_len + and pos_cos.shape[-2] >= seq_len + ): + if pos_sin.device != device: + pos_sin = pos_sin.to(device) + self.__cache["rope_pos_sin"] = pos_sin + if pos_cos.device != device: + pos_cos = pos_cos.to(device) + self.__cache["rope_pos_cos"] = pos_cos + return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :] + + with torch.autocast(device.type, enabled=False): + dim = self.config.d_model // self.config.n_heads + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) + seq = torch.arange(seq_len, device=device, dtype=torch.float) + freqs = einsum("i , j -> i j", seq, inv_freq) + positions = torch.cat((freqs, freqs), dim=-1) + pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :] + self.__cache["rope_pos_sin"] = pos_sin + self.__cache["rope_pos_cos"] = pos_cos + return pos_sin, pos_cos + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + B, nh, T, hs = x.size() + x = x.view(B, nh, T, 2, hs // 2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype) + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.config.rope_full_precision: + q_, k_ = q.float(), k.float() + else: + q_, k_ = q, k + + with torch.autocast(q.device.type, enabled=False): + query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None + pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) + pos_sin = pos_sin.type_as(q_) + pos_cos = pos_cos.type_as(q_) + q_ = self.apply_rotary_pos_emb( + pos_sin[:, :, key_len - query_len : key_len, :], + pos_cos[:, :, key_len - query_len : key_len, :], + q_, + ) + k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_) + return q_.type_as(q), k_.type_as(k) + + +class Activation(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @property + @abstractmethod + def output_multiplier(self) -> float: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig) -> Activation: + if config.activation_type == ActivationType.gelu: + return cast(Activation, GELU(approximate="none")) + elif config.activation_type == ActivationType.relu: + return cast(Activation, ReLU(inplace=False)) + elif config.activation_type == ActivationType.swiglu: + return SwiGLU(config) + else: + raise NotImplementedError(f"Unknown activation: '{config.activation_type}'") + + +class GELU(nn.GELU): + @property + def output_multiplier(self) -> float: + return 1.0 + + +class ReLU(nn.ReLU): + @property + def output_multiplier(self) -> float: + return 1.0 + + +class SwiGLU(Activation): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + @property + def output_multiplier(self) -> float: + return 0.5 + + +def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor: + att_bias = torch.triu( + torch.ones(seq_len, seq_len, device=device, dtype=torch.float), + diagonal=1, + ) + att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min) + return att_bias.view(1, 1, seq_len, seq_len) # type: ignore + + +def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor: + if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len: + if causal_bias.device != device: + causal_bias = causal_bias.to(device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + with torch.autocast(device.type, enabled=False): + causal_bias = causal_attention_bias(seq_len, device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + + +def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor: + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len) + + # shape: (1, 1, seq_len, seq_len) + alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1) + alibi_bias.abs_().mul_(-1) + + # shape: (n_heads,) + m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device) + m.mul_(config.alibi_bias_max / config.n_heads) + + # shape: (1, n_heads, seq_len, seq_len) + return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore + + +class OLMoBlock(nn.Module): + """ + A base class for transformer block implementations. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__() + self.layer_id = layer_id + self.config = config + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. + self.k_norm: Optional[LayerNormBase] = None + self.q_norm: Optional[LayerNormBase] = None + if config.attention_layer_norm: + assert config.effective_n_kv_heads is not None + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Make sure QKV clip coefficient is positive, otherwise it's not well-defined. + if config.clip_qkv is not None: + assert config.clip_qkv > 0 + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Attention output projection. + self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward output projection. + self.ff_out = nn.Linear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.ff_out._is_residual = True # type: ignore + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func # type: ignore + + self.flash_attn_func = flash_attn_func + except ModuleNotFoundError: + pass + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + + if self.config.init_fn == InitFnType.normal: + attn_out_std = ff_out_std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + + elif self.config.init_fn == InitFnType.mitchell: + attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1))) + ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1))) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + elif self.config.init_fn == InitFnType.full_megatron: + attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor) + init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor) + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + if strategy == ActivationCheckpointingStrategy.fine_grained: + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + else: + self._activation_checkpoint_fn = None + + @classmethod + def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: + target_dtype = input_dtype + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. + if bias.device.type == "cuda" and torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + target_dtype = torch.get_autocast_cpu_dtype() + if bias.dtype != target_dtype: + bias = bias.to(target_dtype) + ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) + return bias + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> torch.Tensor: + """ + Computes scaled dot product attention on query, key and value tensors, using an optional + attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. + """ + if self.flash_attn_func is not None and attn_mask is None: + r = self.flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal + ) + return r.transpose(1, 2) + else: + # torch's sdpa doesn't support GQA, so we're doing this + assert k.size(1) == v.size(1) + num_kv_heads = k.size(1) + num_q_heads = q.size(1) + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + + def attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype + + # Optionally apply layer norm to keys and queries. + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) + + # Move head forward to be next to the batch dim. + # shape: (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + present = (k, v) if use_cache else None + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + + if self.config.rope: + # Apply rotary embeddings. + q, k = self.rotary_emb(q, k) + + if attention_bias is not None: + # Resize and cast attention bias. + # The current dtype of the attention bias might not match the dtype that the SDP attn function will + # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding + # as down-casting the attention bias to the autocast precision will result in -infs, which will + # cause the SDP attn function to produce NaNs. + attention_bias = self._cast_attn_bias( + attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype + ) + + # Get the attention scores. + # shape: (B, nh, T, hs) + att = self._scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_bias, + dropout_p=0.0 if not self.training else self.config.attention_dropout, + is_causal=attention_bias is None, + ) + + # Re-assemble all head outputs side-by-side. + att = att.transpose(1, 2).contiguous().view(B, T, C) + + # Apply output projection. + return self.attn_out(att), present + + @abstractmethod + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + raise NotImplementedError + + @classmethod + def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock: + if config.block_type == BlockType.sequential: + return OLMoSequentialBlock(layer_id, config, cache) + elif config.block_type == BlockType.llama: + return OLMoLlamaBlock(layer_id, config, cache) + else: + raise NotImplementedError(f"Unknown block type: '{config.block_type}'") + + +class OLMoSequentialBlock(OLMoBlock): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + # Attention input projection. Projects x -> (q, k, v) + + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device + ) + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + + if self.config.init_fn == InitFnType.normal: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + std = 1 / math.sqrt(self.config.d_model) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.att_proj, std, cutoff_factor) + init_normal(self.ff_proj, std, cutoff_factor) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + if self._activation_checkpoint_fn is not None: + qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)) + else: + qkv = self.att_proj(self.attn_norm(x)) + + if self.config.clip_qkv is not None: + qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + q, k, v = qkv.split(self.fused_dims, dim=-1) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class OLMoLlamaBlock(OLMoBlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). This block is similar to `OLMoSequentialBlock` + but some operations have slightly different implementations to imitate the + behavior of Llama. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + self.__cache = cache + + # Attention input projection. Projects x -> (q, k, v) + if config.multi_query_attention: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model // config.n_heads + v_proj_out_dim = config.d_model // config.n_heads + else: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model + v_proj_out_dim = config.d_model + self.q_proj = nn.Linear( + config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device + ) + self.k_proj = nn.Linear( + config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device + ) + self.v_proj = nn.Linear( + config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + + if self.config.init_fn == InitFnType.normal: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + std = 1 / math.sqrt(self.config.d_model) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.q_proj, std, cutoff_factor) + init_normal(self.k_proj, std, cutoff_factor) + init_normal(self.v_proj, std, cutoff_factor) + init_normal(self.ff_proj, std, cutoff_factor) + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> torch.Tensor: + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) + + if is_causal: + assert attn_mask is None + + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len] + elif attn_mask is not None: + attn_bias = attn_mask.to(q.dtype) + else: + attn_bias = torch.zeros_like(attn_weights) + + attn_weights += attn_bias + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout_p) + return torch.matmul(attn_weights, v) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + x_normed = self.attn_norm(x) + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + if self.config.clip_qkv is not None: + q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Get attention scores. + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class OLMoOutput(NamedTuple): + logits: torch.FloatTensor + """ + A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities + for the next token *before* normalization via (log) softmax. + """ + + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] + """ + Attention keys and values from each block. + """ + + hidden_states: Optional[Tuple[torch.Tensor]] + """ + Hidden states from each block. + """ + + +class OLMoGenerateOutput(NamedTuple): + token_ids: torch.LongTensor + """ + The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`. + These do *not* include the original input IDs. + """ + + scores: torch.FloatTensor + """ + The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`. + """ + + +class OLMoBlockGroup(nn.ModuleList): + def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None): + super().__init__(modules) + self.config = config + self.layer_offset = layer_offset + self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + for block_idx, block in enumerate(self): + layer_past = None if layers_past is None else layers_past[block_idx] + block_idx += self.layer_offset + if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( # type: ignore + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + return x, attn_key_values + + def reset_parameters(self): + for block in self: + block.reset_parameters() + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + self.activation_checkpointing_strategy = strategy + for block in self: + block.set_activation_checkpointing(strategy) + + +class OLMo(nn.Module): + def __init__(self, config: ModelConfig, init_params: bool = True): + super().__init__() + self.config = config + self.__cache = BufferCache() + + # Validate config. + if self.config.alibi and self.config.flash_attention: + raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention") + + if self.config.alibi and self.config.rope: + raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive") + + if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: + if self.config.embedding_size < self.config.vocab_size: + raise OLMoConfigurationError("embedding size should be at least as big as vocab size") + elif self.config.embedding_size % 128 != 0: + import warnings + + warnings.warn( + "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning + ) + + self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config) + + if not ( + 0 < self.config.block_group_size <= self.config.n_layers + and self.config.n_layers % self.config.block_group_size == 0 + ): + raise OLMoConfigurationError("n layers must be divisible by block group size") + + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding( + config.embedding_size or config.vocab_size, config.d_model, device=config.init_device + ), + emb_drop=Dropout(config.embedding_dropout), + ln_f=LayerNorm.build(config), + ) + ) + + blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] + if self.config.block_group_size > 1: + block_groups = [ + OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size]) + for i in range(0, config.n_layers, config.block_group_size) + ] + self.transformer.update({"block_groups": nn.ModuleList(block_groups)}) + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not (self.config.alibi or self.config.rope): + self.transformer.update( + {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} + ) + if not config.weight_tying: + self.transformer.update( + { + "ff_out": nn.Linear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + device=config.init_device, + ) + } + ) + # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights. + if init_params and self.config.init_device != "meta": + self.reset_parameters() + self.__num_fwd_flops: Optional[int] = None + + # Warm up cache. + if self.config.alibi: + get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config)) + self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + self.activation_checkpointing_strategy = strategy + if self.config.block_group_size != 1: + for block_group in self.transformer.block_groups: + block_group.set_activation_checkpointing(strategy) + else: + for block in self.transformer.blocks: + block.set_activation_checkpointing(strategy) + + @property + def device(self) -> torch.device: + device: torch.device = self.transformer.wte.weight.device # type: ignore + if device.type == "meta": + return _non_meta_init_device(self.config) + else: + return device + + def reset_parameters(self): + log.info("Initializing model parameters...") + # Top-level embeddings / linear layers. + + if self.config.init_fn == InitFnType.normal: + # Note: We may potentially want to multiply the std by a factor of sqrt(d) in case of `scale_logits` + # and `weight_tying`. However, we are currently not using either, and may need to rethink the init logic + # if/when we do want it. + wte_std = self.config.init_std + wte_cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + wte_std = 1.0 / math.sqrt(self.config.d_model) + wte_cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + wte_std = self.config.init_std + wte_cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.transformer.wte, std=wte_std, init_cutoff_factor=wte_cutoff_factor) + + if hasattr(self.transformer, "wpe"): + if self.config.init_fn == InitFnType.normal: + wpe_std = self.config.init_std + wpe_cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + wpe_std = 1 / math.sqrt(self.config.d_model) + wpe_cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + wpe_std = self.config.init_std + wpe_cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.transformer.wpe, std=wpe_std, init_cutoff_factor=wpe_cutoff_factor) + + # Top-level layer norm. + self.transformer.ln_f.reset_parameters() # type: ignore + + # Output weights. + if hasattr(self.transformer, "ff_out"): + if self.config.init_fn == InitFnType.normal: + ff_out_std = self.config.init_std + ff_out_cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + ff_out_std = 1 / math.sqrt(self.config.d_model) + ff_out_cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + ff_out_std = 1 / math.sqrt(self.config.d_model) + ff_out_cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.transformer.ff_out, ff_out_std, ff_out_cutoff_factor) + + # Let the blocks handle themselves. + if self.config.block_group_size == 1: + for block in self.transformer.blocks: + block.reset_parameters() + else: + for block_group in self.transformer.block_groups: + block_group.reset_parameters() + + def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor: + if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[ + -1 + ] >= seq_len: + if alibi_bias.device != device: + alibi_bias = alibi_bias.to(device) + self.__cache["alibi_attention_bias"] = alibi_bias + return alibi_bias + with torch.autocast(device.type, enabled=False): + alibi_bias = alibi_attention_bias(seq_len, self.config, device) + self.__cache["alibi_attention_bias"] = alibi_bias + return alibi_bias + + def forward( + self, + input_ids: torch.LongTensor, + input_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + last_logits_only: bool = False, + output_hidden_states: Optional[bool] = None, + ) -> OLMoOutput: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input + embeddings. When provided, it is treated as the output of the input embedding layer. + :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates + which input IDs are masked. A `1` value in the mask means that + the corresponding input ID should *not* be ignored. A `0` means + that the corresponding input ID is masked. + + This has the same meaning as the `attention_mask` in HuggingFace's `transformers` + library. + :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`, + `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used + to introduce causal or other biases. + + If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]` + indicates that the i-th element in the sequence is allowed to attend to the j-th + element in the sequence. + + If the tensor is a float tensor, it will just be added to the attention + scores before the softmax. + + The default is causal, which corresponds to a lower-diagonal byte matrix of ones. + :param past_key_values: Pre-computed keys and values for each attention block. + Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + :param use_cache: If `True`, return key and value tensors for each block. + :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. + This can speed up decoding when you only care about the next token. + """ + output_hidden_states = output_hidden_states if output_hidden_states is not None else False + + if past_key_values: + assert len(past_key_values) == self.config.n_layers + + batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2] + if past_key_values is None: + past_length = 0 + else: + past_length = past_key_values[0][0].size(-2) + + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore + + if not (self.config.alibi or self.config.rope): + # Get positional embeddings. + # shape: (1, seq_len) + pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) + # shape: (1, seq_len, d_model) + pos_emb = self.transformer.wpe(pos) # type: ignore + x = pos_emb + x + + # Add input + positional embeddings and apply dropout. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.emb_drop(x) # type: ignore + + # Transform the attention mask into what the blocks expect. + if attention_mask is not None: + # shape: (batch_size, 1, 1, seq_len) + attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min + + # Merge attention mask with attention bias. + if ( + attention_bias is not None + or attention_mask is not None + or self.config.alibi + # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly + # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute + # scores correctly. + or past_key_values is not None + ): + if attention_bias is None and self.config.alibi: + attention_bias = get_causal_attention_bias( + self.__cache, past_length + seq_len, x.device + ) + self.get_alibi_attention_bias(past_length + seq_len, x.device) + elif attention_bias is None: + attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) + elif attention_bias.dtype in (torch.int8, torch.bool): + attention_bias = attention_bias.to(dtype=torch.float) + attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min) + + # Transform to the right shape and data type. + mask_len = seq_len + if attention_mask is not None: + mask_len = attention_mask.shape[-1] + elif past_key_values is not None: + mask_len = past_key_values[0][0].shape[-2] + seq_len + attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float) + + # Add in the masking bias. + if attention_mask is not None: + attention_bias = attention_bias + attention_mask + # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf. + # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead + # it can produce NaNs. + ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False) + + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + + # decoder layers + all_hidden_states = [] + + # Apply blocks one-by-one. + if self.config.block_group_size == 1: + for block_idx, block in enumerate(self.transformer.blocks): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layer_past = None if past_key_values is None else past_key_values[block_idx] + if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + else: + for group_idx, block_group in enumerate(self.transformer.block_groups): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layers_past = ( + None + if past_key_values is None + else past_key_values[ + group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size + ] + ) + x, cache = block_group( + x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.extend(cache) + + if last_logits_only: + # shape: (batch_size, 1, d_model) + x = x[:, -1, :].unsqueeze(1) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + if output_hidden_states: + # add final hidden state post-final-layernorm, following HuggingFace's convention + all_hidden_states.append(x) + + # Get logits. + # shape: (batch_size, seq_len or 1, vocab_size) + if self.config.weight_tying: + logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore + else: + logits = self.transformer.ff_out(x) # type: ignore + if self.config.scale_logits: + logits.mul_(1 / math.sqrt(self.config.d_model)) + return logits + #return OLMoOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type] + + def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None): + if wrap_strategy is None: + return None + + # The 'recurse' mode for the wrap function does not behave like you'd expect. + # Even if we return False, it may still recurse because PyTorch does what it wants, + # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer) + # but not other linear layers within a block. + # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just + # return True in 'recurse' mode for simplicity. + size_based_module_to_wrap = {self.transformer.wte} + if hasattr(self.transformer, "ff_out"): + size_based_module_to_wrap.add(self.transformer.ff_out) + + if wrap_strategy == FSDPWrapStrategy.by_block: + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, OLMoBlock) + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_and_size: + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_group: + if self.config.block_group_size <= 1: + raise OLMoConfigurationError( + "'by_block_group' FSDP wrapping strategy requires block group size greater than 1" + ) + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, OLMoBlockGroup) + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size: + if self.config.block_group_size <= 1: + raise OLMoConfigurationError( + "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1" + ) + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.size_based: + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + + return size_based_auto_wrap_policy + elif wrap_strategy in { + FSDPWrapStrategy.one_in_two, + FSDPWrapStrategy.one_in_three, + FSDPWrapStrategy.one_in_four, + FSDPWrapStrategy.one_in_five, + }: + c = { + FSDPWrapStrategy.one_in_two: 2, + FSDPWrapStrategy.one_in_three: 3, + FSDPWrapStrategy.one_in_four: 4, + FSDPWrapStrategy.one_in_five: 5, + }[wrap_strategy] + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0 + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + else: + raise NotImplementedError(wrap_strategy) + + def num_params(self, include_embedding: bool = True) -> int: + """ + Get the total number of parameters. + """ + params = (np for np in self.named_parameters()) + if not include_embedding: + params = filter( # type: ignore + lambda np: ".wte." not in np[0] and ".wpe." not in np[0], + params, + ) + return sum(p.numel() for _, p in params) + + @property + def num_fwd_flops(self): + if self.__num_fwd_flops: + return self.__num_fwd_flops + n_params = self.num_params() + # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network + # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param + # this gets us FLOPs / token + params_flops_per_token = 2 * n_params + params_flops_per_seq = params_flops_per_token * self.config.max_sequence_length + # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2) + attn_flops_per_seq = ( + self.config.n_layers * 2 * 2 * (self.config.d_model * (self.config.max_sequence_length**2)) + ) + self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq + return self.__num_fwd_flops + + def generate( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + max_steps: int = 10, + beam_size: int = 1, + per_node_beam_size: Optional[int] = None, + sampler: Optional[Sampler] = None, + min_steps: Optional[int] = None, + final_sequence_scorer: Optional[FinalSequenceScorer] = None, + constraints: Optional[List[Constraint]] = None, + ) -> OLMoGenerateOutput: + """ + Generate token IDs using beam search. + + Note that by default ``beam_size`` is set to 1, which is greedy decoding. + + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param attention_mask: A optional tensor of shape `(batch_size, seq_len)`, the same + as for the forward method. + :param attention_bias: A tensor of shape + `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`, + the same as for the forward method except only one shape is excepted here. + + For an explanation of the other arguments, see :class:`BeamSearch`. + """ + beam_search = BeamSearch( + self.config.eos_token_id, + max_steps=max_steps, + beam_size=beam_size, + per_node_beam_size=per_node_beam_size, + sampler=sampler, + min_steps=min_steps, + final_sequence_scorer=final_sequence_scorer, + constraints=constraints, + ) + + # Validate inputs. + batch_size, seq_len = input_ids.shape + if attention_mask is not None: + assert attention_mask.shape == (batch_size, seq_len) + if attention_bias is not None: + assert len(attention_bias.shape) == 4 + assert attention_bias.shape[:2] == (batch_size, 1) + assert ( + seq_len + beam_search.max_steps + <= attention_bias.shape[2] + == attention_bias.shape[3] + <= self.config.max_sequence_length + ) + + tokens_generated = 0 + + def flatten_past_key_values( + past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> Dict[str, torch.Tensor]: + out = {} + for i, (key, value) in enumerate(past_key_values): + out[f"past_key_{i}"] = key + out[f"past_value_{i}"] = value + return out + + def unflatten_past_key_values( + past_key_values: Dict[str, torch.Tensor], + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + out = [] + for i in range(self.config.n_layers): + past_key = past_key_values[f"past_key_{i}"] + past_value = past_key_values[f"past_value_{i}"] + out.append((past_key, past_value)) + return out + + def step( + last_predictions: torch.Tensor, state: dict[str, torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + nonlocal tokens_generated + + attention_mask = state.get("attention_mask") + attention_bias = state.get("attention_bias") + + if tokens_generated > 0: + past_key_values = unflatten_past_key_values(state) + input_ids = last_predictions.unsqueeze(1) + if attention_mask is not None: + group_size = input_ids.shape[0] + attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1) + else: + past_key_values = None + input_ids = state["input_ids"] + + tokens_generated += 1 + + # Run forward pass of model to get logits, then normalize to get log probs. + output = self( + input_ids, + attention_mask=attention_mask, + attention_bias=attention_bias, + past_key_values=past_key_values, + use_cache=True, + last_logits_only=True, + ) + log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1) + + # Create new state. + state = flatten_past_key_values(output.attn_key_values) + if attention_mask is not None: + state["attention_mask"] = attention_mask + if attention_bias is not None: + state["attention_bias"] = attention_bias + + return log_probs, state + + initial_preds = input_ids.new_zeros((batch_size,)) # This is arbitrary, we won't use this. + state: dict[str, torch.Tensor] = {"input_ids": input_ids} + if attention_mask is not None: + state["attention_mask"] = attention_mask + if attention_bias is not None: + state["attention_bias"] = attention_bias + with torch.no_grad(): + token_ids, scores = beam_search.search(initial_preds, state, step) + + return OLMoGenerateOutput( + token_ids=token_ids, # type: ignore[arg-type] + scores=scores, # type: ignore[arg-type] + ) + + @classmethod + def from_checkpoint( + cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None + ) -> OLMo: + """ + Load an OLMo model from a checkpoint. + """ + from .util import resource_path + + # Guess checkpoint type. + if checkpoint_type is None: + try: + if resource_path(checkpoint_dir, "model.pt").is_file(): + checkpoint_type = CheckpointType.unsharded + else: + checkpoint_type = CheckpointType.sharded + except FileNotFoundError: + checkpoint_type = CheckpointType.sharded + + # Load config. + config_path = resource_path(checkpoint_dir, "config.yaml") + model_config = ModelConfig.load(config_path, key="model", validate_paths=False) + + if checkpoint_type == CheckpointType.unsharded: + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + model_config.init_device = "cpu" + model = OLMo(model_config) + + # Load state dict directly to target device. + state_dict_path = resource_path(checkpoint_dir, "model.pt") + state_dict = torch.load(state_dict_path, map_location="cpu") + model.load_state_dict(model._make_state_dict_compatible(state_dict)[0]) + model = model.to(torch.device(device)) + else: + from .checkpoint import load_model_state + + # Initialize model on target device. In this case the state dict is loaded in-place + # so it's not necessary to start on CPU if the target device is a GPU. + model_config.init_device = device + model = OLMo(model_config) + + # Load state dict in place. + load_model_state(checkpoint_dir, model) + + return model.eval() + + def _make_state_dict_compatible( + self, state_dict: Dict[str, torch.Tensor] + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]: + """ + Handles some cases where the state dict is valid yet may need to be transformed in order to + be loaded. + + This modifies the state dict in-place and also returns it, along with a mapping of original key + names to new key names in cases where the keys were simply renamed. That mapping can be used + to make a corresponding optimizer state dict compatible as well. + """ + import re + from fnmatch import fnmatch + + new_keys_to_og_keys: Dict[str, str] = {} + + # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is + # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work + # fine without the prefixes. This also simplifies the other steps below. + for key in list(state_dict.keys()): + state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = key + + # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222 + if self.config.block_type == BlockType.sequential: + for key in list(state_dict.keys()): + if fnmatch(key, "transformer.*.norm.weight"): + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone() + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.*.norm.bias"): + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone() + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + + # For loading a state dict that was saved with a different `block_group_size`. + if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys(): + state_dict_block_group_size = len( + [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")] + ) + else: + state_dict_block_group_size = 1 + if self.config.block_group_size != state_dict_block_group_size: + log.info( + f"Regrouping state dict blocks from group size {state_dict_block_group_size} to " + f"group size {self.config.block_group_size}" + ) + # For simplicity we're first going to flatten out the block groups in the state dict (if necessary) + # and then (re-)group them into the right block sizes. + if state_dict_block_group_size > 1: + for key in list(state_dict.keys()): + if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None: + group_idx, group_block_idx = int(m.group(1)), int(m.group(2)) + block_idx = (group_idx * state_dict_block_group_size) + group_block_idx + state_dict[ + ( + new_key := key.replace( + f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}." + ) + ) + ] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key) + + if self.config.block_group_size > 1: + # Group the state dict blocks into the right block size. + for key in list(state_dict.keys()): + if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None: + block_idx = int(m.group(1)) + group_idx, group_block_idx = ( + block_idx // self.config.block_group_size, + block_idx % self.config.block_group_size, + ) + state_dict[ + ( + new_key := key.replace( + f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}." + ) + ) + ] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key) + + og_keys_to_new: Dict[str, Set[str]] = defaultdict(set) + for new_key, og_key in new_keys_to_og_keys.items(): + og_keys_to_new[og_key].add(new_key) + + return state_dict, og_keys_to_new \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/tokenizer.py b/3.test_cases/neuronx-distributed/olmo/olmo/tokenizer.py new file mode 100644 index 00000000..25a94be5 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/tokenizer.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import List, Optional, Union + +from tokenizers import Tokenizer as BaseTokenizer + +from .aliases import PathOrStr +from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection +from .exceptions import OLMoConfigurationError + +__all__ = ["Tokenizer"] + + +class Tokenizer: + """ + A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`. + + :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use. + :param eos_token_id: The token ID corresponding to the "end-of-sentence" token. + :param truncate_to: Truncate when tokenizing to this number of token IDs. + :param truncate_direction: The direction to truncate in. "right" means truncate the tokens + on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null, + this setting has no effect. + """ + + def __init__( + self, + base_tokenizer: BaseTokenizer, + eos_token_id: int, + pad_token_id: Optional[int] = None, + truncate_to: Optional[int] = None, + truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right, + ): + self.base_tokenizer = base_tokenizer + self.base_tokenizer.no_truncation() + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id + self.truncate_to = truncate_to + self.truncate_direction = TruncationDirection(truncate_direction) + + @property + def vocab_size(self) -> int: + return self.base_tokenizer.get_vocab_size() + + @property + def eos_token(self) -> str: + return self.decode([self.eos_token_id], skip_special_tokens=False) + + @property + def pad_token(self) -> str: + return self.decode([self.pad_token_id], skip_special_tokens=False) + + @classmethod + def from_train_config(cls, config: TrainConfig) -> Tokenizer: + tokenizer_identifier = config.tokenizer.identifier + if Path(tokenizer_identifier).is_file(): + tokenizer = cls.from_file( + tokenizer_identifier, + eos_token_id=config.model.eos_token_id, + pad_token_id=config.model.pad_token_id, + ) + else: + tokenizer = cls.from_pretrained( + tokenizer_identifier, + eos_token_id=config.model.eos_token_id, + pad_token_id=config.model.pad_token_id, + ) + if config.model.vocab_size != tokenizer.vocab_size: + raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") + return tokenizer + + @classmethod + def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer: + """ + Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub. + + :param identifier: The identifier of a model on the Hub that contains a + ``tokenizer.json`` file. + :param kwargs: Other key word arguments passed to :class:`Tokenizer`. + """ + base_tokenizer = BaseTokenizer.from_pretrained(identifier) + eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1) + return cls(base_tokenizer, eos_token_id, **kwargs) + + @classmethod + def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer: + """ + Initialize a tokenizer from a file. + + You can create those files with ``BaseTokenizer.save()``. + + :param filename: The name of a file containing a tokenizer specification. + :param kwargs: Other key word arguments passed to :class:`Tokenizer`. + """ + base_tokenizer = BaseTokenizer.from_file(filename) + eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1) + return cls(base_tokenizer, eos_token_id, **kwargs) + + @classmethod + def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer: + """ + Load a tokenizer from a checkpoint. + """ + from cached_path import cached_path + + # Load configs. + config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml")) + tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer") + model_config = ModelConfig.load(config_path, key="model") + + # Initialize tokenizer and validate vocab size. + if Path(tokenizer_config.identifier).is_file(): + tokenizer = cls.from_file( + tokenizer_config.identifier, + eos_token_id=model_config.eos_token_id, + pad_token_id=model_config.pad_token_id, + ) + else: + tokenizer = cls.from_pretrained( + tokenizer_config.identifier, + eos_token_id=model_config.eos_token_id, + pad_token_id=model_config.pad_token_id, + ) + if model_config.vocab_size != tokenizer.vocab_size: + raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") + return tokenizer + + def add_special_tokens(self, input_ids: List[int]) -> List[int]: + """ + Add special tokens in-place (if not already present) to the given token IDs. + """ + if not input_ids or input_ids[-1] != self.eos_token_id: + input_ids.append(self.eos_token_id) + return input_ids + + def num_special_tokens_to_add(self, is_pair: bool = False) -> int: + return 2 if is_pair else 1 + + def _truncate( + self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection + ) -> list[int]: + if truncate_to is None or len(input_ids) <= truncate_to: + return input_ids + elif direction == TruncationDirection.left: + return input_ids[len(input_ids) - truncate_to :] + else: + return input_ids[: -(len(input_ids) - truncate_to)] + + def encode(self, input: str, add_special_tokens: bool = True) -> List[int]: + """ + Encode a string into token IDs. + """ + return self.encode_batch([input], add_special_tokens=add_special_tokens)[0] + + def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]: + """ + Encode a batch of strings into token IDs. + """ + truncate_to = self.truncate_to + if truncate_to is not None and add_special_tokens: + truncate_to -= self.num_special_tokens_to_add(False) + + batch_encoding = self.base_tokenizer.encode_batch(inputs) + + all_input_ids = [] + for encoding in batch_encoding: + input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction) + if add_special_tokens: + input_ids = self.add_special_tokens(input_ids) + all_input_ids.append(input_ids) + + return all_input_ids + + def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str: + """ + Decode a list of token IDs to a string. + """ + return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/torch_util.py b/3.test_cases/neuronx-distributed/olmo/olmo/torch_util.py new file mode 100644 index 00000000..26bffdde --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/torch_util.py @@ -0,0 +1,142 @@ +import gc +import os +from typing import Optional, TypeVar + +import torch +import torch.distributed as dist + +T = TypeVar("T") + + +def seed_all(seed: int): + """Seed all rng objects.""" + import random + + import numpy as np + + if seed < 0 or seed > 2**32 - 1: + raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + # torch.manual_seed may call manual_seed_all but calling it again here + # to make sure it gets called at least once + torch.cuda.manual_seed_all(seed) + + +def is_distributed() -> bool: + return dist.is_available() and dist.is_initialized() + + +def get_node_rank() -> int: + return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size()) + + +def get_world_size() -> int: + if is_distributed(): + return dist.get_world_size() + else: + return 1 + + +def get_local_world_size() -> int: + return int(os.environ.get("LOCAL_WORLD_SIZE") or 1) + + +def get_global_rank() -> int: + if is_distributed(): + return int(os.environ.get("RANK") or dist.get_rank()) + else: + return 0 + + +def get_local_rank() -> int: + return int(os.environ.get("LOCAL_RANK") or 0) + + +def get_fs_local_rank() -> int: + """Get the local rank per filesystem, meaning that, regardless of the number of nodes, + if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`, + but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`. + """ + return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank()) + + +def move_to_device(o: T, device: torch.device) -> T: + if isinstance(o, torch.Tensor): + return o.to(device) # type: ignore[return-value] + elif isinstance(o, dict): + return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value] + elif isinstance(o, list): + return [move_to_device(x, device) for x in o] # type: ignore[return-value] + elif isinstance(o, tuple): + return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value] + else: + return o + + +def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False): + """ + Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf`` + is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``. + """ + if check_neg_inf: + x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min) + if check_pos_inf: + x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max) + + +def get_default_device() -> torch.device: + if torch.cuda.is_available() and torch.cuda.is_initialized(): + return torch.device("cuda") + else: + return torch.device("cpu") + + +def barrier() -> None: + if is_distributed(): + dist.barrier() + + +def peak_gpu_memory(reset: bool = False) -> Optional[float]: + """ + Get the peak GPU memory usage in MB across all ranks. + Only rank 0 will get the final result. + """ + if not torch.cuda.is_available(): + return None + + device = torch.device("cuda") + peak_mb = torch.cuda.max_memory_allocated(device) / 1000000 + if is_distributed(): + peak_mb_tensor = torch.tensor(peak_mb, device=device) + dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX) + peak_mb = peak_mb_tensor.item() + + if reset: + # Reset peak stats. + torch.cuda.reset_max_memory_allocated(device) + + return peak_mb + + +V = TypeVar("V", bool, int, float) + + +def synchronize_value(value: V, device: torch.device) -> V: + if dist.is_available() and dist.is_initialized(): + value_tensor = torch.tensor(value, device=device) + dist.broadcast(value_tensor, 0) + return value_tensor.item() # type: ignore + else: + return value + + +def synchronize_flag(flag: bool, device: torch.device) -> bool: + return synchronize_value(flag, device) + + +def gc_cuda(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/olmo/util.py b/3.test_cases/neuronx-distributed/olmo/olmo/util.py new file mode 100644 index 00000000..b5c0b1de --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/olmo/util.py @@ -0,0 +1,761 @@ +import io +import logging +import os +import re +import socket +import sys +import time +import warnings +from datetime import datetime +from enum import Enum +from itertools import cycle, islice +from pathlib import Path +from queue import Queue +from threading import Thread +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import boto3 +import botocore.exceptions as boto_exceptions +import rich +from botocore.config import Config +from cached_path.schemes import SchemeClient, add_scheme_client +from rich.console import Console, ConsoleRenderable +from rich.highlighter import NullHighlighter +from rich.progress import Progress +from rich.text import Text +from rich.traceback import Traceback + +from .aliases import PathOrStr +from .exceptions import ( + OLMoCliError, + OLMoEnvironmentError, + OLMoError, + OLMoNetworkError, + OLMoThreadError, +) +from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed + +try: + from functools import cache +except ImportError: + from functools import lru_cache as cache + + +class StrEnum(str, Enum): + """ + This is equivalent to Python's :class:`enum.StrEnum` since version 3.11. + We include this here for compatibility with older version of Python. + """ + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"'{str(self)}'" + + +_log_extra_fields: Dict[str, Any] = {} +log = logging.getLogger(__name__) + + +class LogFilterType(StrEnum): + rank0_only = "rank0_only" + local_rank0_only = "local_rank0_only" + all_ranks = "all_ranks" + + +def log_extra_field(field_name: str, field_value: Any) -> None: + global _log_extra_fields + if field_value is None: + if field_name in _log_extra_fields: + del _log_extra_fields[field_name] + else: + _log_extra_fields[field_name] = field_value + + +def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None: + """ + :param rank0_only: INFO and below messages will only be emitted on the rank0 process. + """ + log_extra_field("hostname", socket.gethostname()) + if is_distributed(): + log_extra_field("node_rank", get_node_rank()) + log_extra_field("local_rank", get_local_rank()) + log_extra_field("global_rank", get_global_rank()) + else: + log_extra_field("node_rank", 0) + log_extra_field("local_rank", 0) + log_extra_field("global_rank", 0) + + old_log_record_factory = logging.getLogRecordFactory() + + def log_record_factory(*args, **kwargs) -> logging.LogRecord: + record = old_log_record_factory(*args, **kwargs) + for field_name, field_value in _log_extra_fields.items(): + setattr(record, field_name, field_value) + return record + + logging.setLogRecordFactory(log_record_factory) + + handler: logging.Handler + if ( + os.environ.get("OLMo_NONINTERACTIVE", False) + or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive" + or not sys.stdout.isatty() + ): + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s" + ) + formatter.default_time_format = "%Y-%m-%d %H:%M:%S" + formatter.default_msec_format = "%s.%03d" + handler.setFormatter(formatter) + else: + handler = RichHandler() + + def rank0_filter(record: logging.LogRecord) -> int: + if record.levelno > logging.INFO: + return 1 + if getattr(record, "global_rank", 0) == 0: + return 1 + else: + return 0 + + def local_rank0_filter(record: logging.LogRecord) -> int: + if record.levelno > logging.INFO: + return 1 + if getattr(record, "local_rank", 0) == 0: + return 1 + else: + return 0 + + if log_filter_type == LogFilterType.rank0_only: + filter = rank0_filter + elif log_filter_type == LogFilterType.local_rank0_only: + filter = local_rank0_filter # type: ignore + elif log_filter_type == LogFilterType.all_ranks: + filter = None + else: + raise ValueError(log_filter_type) + + if filter is not None: + handler.addFilter(filter) # type: ignore + logging.basicConfig(handlers=[handler], level=logging.INFO) + + logging.captureWarnings(True) + logging.getLogger("urllib3").setLevel(logging.ERROR) + + +def excepthook(exctype, value, traceback): + """ + Used to patch `sys.excepthook` in order to log exceptions. + """ + if issubclass(exctype, KeyboardInterrupt): + sys.__excepthook__(exctype, value, traceback) + elif issubclass(exctype, OLMoCliError): + rich.get_console().print(f"[yellow]{value}[/]", highlight=False) + elif issubclass(exctype, OLMoError): + rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False) + else: + log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback)) + + +def install_excepthook(): + sys.excepthook = excepthook + + +def filter_warnings(): + # Filter internal deprecation warnings from torch + warnings.filterwarnings( + action="ignore", + category=UserWarning, + message="torch.distributed.*_base is a private function and will be deprecated.*", + ) + warnings.filterwarnings( + action="ignore", + category=UserWarning, + message="TypedStorage is deprecated.*", + ) + warnings.filterwarnings( + action="ignore", + category=UserWarning, + message="Please use DTensor instead.*", + ) + # Torchvision warnings. We don't actually use torchvision. + warnings.filterwarnings( + action="ignore", + message="failed to load.*", + module="torchvision.io.image", + ) + + +def set_env_variables(): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None): + if log_filter_type is None: + log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only")) + rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True) + setup_logging(log_filter_type=log_filter_type) + install_excepthook() + filter_warnings() + set_env_variables() + + +def clean_opt(arg: str) -> str: + if "=" not in arg: + arg = f"{arg}=True" + name, val = arg.split("=", 1) + name = name.strip("-").replace("-", "_") + return f"{name}={val}" + + +class RichHandler(logging.Handler): + """ + A simplified version of rich.logging.RichHandler from + https://github.com/Textualize/rich/blob/master/rich/logging.py + """ + + def __init__( + self, + *, + level: Union[int, str] = logging.NOTSET, + console: Optional[Console] = None, + markup: bool = False, + ) -> None: + super().__init__(level=level) + self.console = console or rich.get_console() + self.highlighter = NullHighlighter() + self.markup = markup + + def emit(self, record: logging.LogRecord) -> None: + try: + if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"): + self.console.print(record.msg) + else: + msg: Any = record.msg + if isinstance(record.msg, str): + msg = self.render_message(record=record, message=record.getMessage()) + renderables = [ + self.get_time_text(record), + self.get_level_text(record), + self.get_location_text(record), + msg, + ] + if record.exc_info is not None: + tb = Traceback.from_exception(*record.exc_info) # type: ignore + renderables.append(tb) + self.console.print(*renderables) + except Exception: + self.handleError(record) + + def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable: + use_markup = getattr(record, "markup", self.markup) + message_text = Text.from_markup(message) if use_markup else Text(message) + + highlighter = getattr(record, "highlighter", self.highlighter) + if highlighter: + message_text = highlighter(message_text) + + return message_text + + def get_time_text(self, record: logging.LogRecord) -> Text: + log_time = datetime.fromtimestamp(record.created) + time_str = log_time.strftime("[%Y-%m-%d %X]") + return Text(time_str, style="log.time", end=" ") + + def get_level_text(self, record: logging.LogRecord) -> Text: + level_name = record.levelname + level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}") + level_text.style = "log.level" + level_text.end = " " + return level_text + + def get_location_text(self, record: logging.LogRecord) -> Text: + name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root" + text = f"[{name_and_line}, rank={record.local_rank}]" # type: ignore + return Text(text, style="log.path") + + +def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0): + """Wait for the condition function to return True.""" + start_time = time.monotonic() + while not condition(): + time.sleep(0.5) + if time.monotonic() - start_time > timeout: + raise TimeoutError(f"{description} timed out") + + +def is_url(path: PathOrStr) -> bool: + return re.match(r"[a-z0-9]+://.*", str(path)) is not None + + +def dir_is_empty(dir: PathOrStr) -> bool: + dir = Path(dir) + if not dir.is_dir(): + return True + try: + next(dir.glob("*")) + return False + except StopIteration: + return True + + +def get_progress_bar() -> Progress: + from cached_path import get_download_progress + + return get_download_progress() + + +def resource_path( + folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None +) -> Path: + if local_cache is not None and (local_path := Path(local_cache) / fname).is_file(): + log.info(f"Found local cache of {fname} at {local_path}") + return local_path + else: + from cached_path import cached_path + + return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress) + + +def file_size(path: PathOrStr) -> int: + """ + Get the size of a local or remote file in bytes. + """ + if is_url(path): + from urllib.parse import urlparse + + parsed = urlparse(str(path)) + if parsed.scheme == "gs": + return _gcs_file_size(parsed.netloc, parsed.path.strip("/")) + elif parsed.scheme in ("s3", "r2", "weka"): + return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/")) + elif parsed.scheme in ("http", "https"): + return _http_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/")) + elif parsed.scheme == "file": + return file_size(str(path).replace("file://", "", 1)) + else: + raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files") + else: + return os.stat(path).st_size + + +def upload(source: PathOrStr, target: str, save_overwrite: bool = False): + """Upload source file to a target location on GCS or S3.""" + from urllib.parse import urlparse + + source = Path(source) + assert source.is_file() + parsed = urlparse(target) + if parsed.scheme == "gs": + _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite) + elif parsed.scheme in ("s3", "r2", "weka"): + _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite) + else: + raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme") + + +def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes: + if is_url(source): + from urllib.parse import urlparse + + parsed = urlparse(str(source)) + if parsed.scheme == "gs": + return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes) + elif parsed.scheme in ("s3", "r2", "weka"): + return _s3_get_bytes_range( + parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes + ) + elif parsed.scheme in ("http", "https"): + return _http_get_bytes_range( + parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes + ) + elif parsed.scheme == "file": + return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes) + else: + raise NotImplementedError(f"get bytes range not implemented for '{parsed.scheme}' files") + else: + with open(source, "rb") as f: + f.seek(bytes_start) + return f.read(num_bytes) + + +def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: + if is_url(dir): + from urllib.parse import urlparse + + parsed = urlparse(str(dir)) + if parsed.scheme == "gs": + raise NotImplementedError + elif parsed.scheme in ("s3", "r2", "weka"): + return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/")) + elif parsed.scheme == "file": + return find_latest_checkpoint(str(dir).replace("file://", "", 1)) + else: + raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files") + else: + latest_step = 0 + latest_checkpoint: Optional[Path] = None + for path in Path(dir).glob("step*"): + if path.is_dir(): + try: + step = int(path.name.replace("step", "").replace("-unsharded", "")) + except ValueError: + continue + # We prioritize sharded checkpoints over unsharded checkpoints. + if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")): + latest_step = step + latest_checkpoint = path + return latest_checkpoint + + +def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): + from google.cloud import storage as gcs + + storage_client = gcs.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(key) + if not save_overwrite and blob.exists(): + raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.") + blob.upload_from_filename(source) + + +def _gcs_file_size(bucket_name: str, key: str) -> int: + from google.api_core.exceptions import NotFound + from google.cloud import storage as gcs + + storage_client = gcs.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(key) + try: + blob.reload() + except NotFound: + raise FileNotFoundError(f"gs://{bucket_name}/{key}") + assert blob.size is not None + return blob.size + + +def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes: + from google.api_core.exceptions import NotFound + from google.cloud import storage as gcs + + storage_client = gcs.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(key) + try: + blob.reload() + except NotFound: + raise FileNotFoundError(f"gs://{bucket_name}/{key}") + return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1) + + +def _get_s3_profile_name(scheme: str) -> Optional[str]: + if scheme == "s3": + # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set. + return os.environ.get("S3_PROFILE") + if scheme == "r2": + profile_name = os.environ.get("R2_PROFILE") + if profile_name is None: + raise OLMoEnvironmentError( + "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?" + ) + + return profile_name + if scheme == "weka": + profile_name = os.environ.get("WEKA_PROFILE") + if profile_name is None: + raise OLMoEnvironmentError( + "Weka profile name is not set. Did you forget to set the 'WEKA_PROFILE' env var?" + ) + + return profile_name + + raise NotImplementedError(f"Cannot get profile name for scheme {scheme}") + + +def _get_s3_endpoint_url(scheme: str) -> Optional[str]: + if scheme == "s3": + return None + if scheme == "r2": + r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL") + if r2_endpoint_url is None: + raise OLMoEnvironmentError( + "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?" + ) + + return r2_endpoint_url + if scheme == "weka": + weka_endpoint_url = os.environ.get("WEKA_ENDPOINT_URL") + if weka_endpoint_url is None: + raise OLMoEnvironmentError( + "Weka endpoint url is not set. Did you forget to set the 'WEKA_ENDPOINT_URL' env var?" + ) + + return weka_endpoint_url + + raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}") + + +@cache +def _get_s3_client(scheme: str): + session = boto3.Session(profile_name=_get_s3_profile_name(scheme)) + return session.client( + "s3", + endpoint_url=_get_s3_endpoint_url(scheme), + config=Config(retries={"max_attempts": 10, "mode": "standard"}), + use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")), + ) + + +def _wait_before_retry(attempt: int): + time.sleep(min(0.5 * 2**attempt, 3.0)) + + +def _s3_upload( + source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3 +): + err: Optional[Exception] = None + if not save_overwrite: + for attempt in range(1, max_attempts + 1): + try: + _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key) + raise FileExistsError( + f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." + ) + except boto_exceptions.ClientError as e: + if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: + err = None + break + err = e + + if attempt < max_attempts: + log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err) + _wait_before_retry(attempt) + + if err is not None: + raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err + + try: + _get_s3_client(scheme).upload_file(source, bucket_name, key) + except boto_exceptions.ClientError as e: + raise OLMoNetworkError(f"Failed to upload to {scheme}") from e + + +def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int: + err: Optional[Exception] = None + for attempt in range(1, max_attempts + 1): + try: + return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"] + except boto_exceptions.ClientError as e: + if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: + raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e + err = e + + if attempt < max_attempts: + log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err) + _wait_before_retry(attempt) + + raise OLMoNetworkError(f"Failed to get {scheme} file size") from err + + +def _s3_get_bytes_range( + scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3 +) -> bytes: + err: Optional[Exception] = None + for attempt in range(1, max_attempts + 1): + try: + return ( + _get_s3_client(scheme) + .get_object( + Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}" + )["Body"] + .read() + ) + except boto_exceptions.ClientError as e: + if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: + raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e + err = e + except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e: + # ResponseStreamingError (subclass of HTTPClientError) can happen as + # a result of a failed read from the stream (http.client.IncompleteRead). + # Retrying can help in this case. + err = e + + if attempt < max_attempts: + log.warning( + "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err + ) + _wait_before_retry(attempt) + + # When torch's DataLoader intercepts exceptions, it may try to re-raise them + # by recalling their constructor with a single message arg. Torch has some + # logic to deal with the absence of a single-parameter constructor, but it + # doesn't gracefully handle other possible failures in calling such a constructor + # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting + # in us losing the true exception info. To avoid this, we change the exception + # to a type that has a single-parameter constructor. + raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err + + +def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]: + if not prefix.endswith("/"): + prefix = f"{prefix}/" + response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/") + assert not response["IsTruncated"] # need to handle this if it happens + latest_step = 0 + latest_checkpoint: Optional[str] = None + for item in response["CommonPrefixes"]: + prefix = item["Prefix"].strip("/") + checkpoint_name = os.path.split(prefix)[-1] + if not checkpoint_name.startswith("step"): + continue + try: + step = int(checkpoint_name.replace("step", "").replace("-unsharded", "")) + except ValueError: + continue + # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete + # (upload might have have failed part way through). + try: + _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml") + except FileNotFoundError: + continue + # We prioritize sharded checkpoints over unsharded ones. + if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")): + latest_step = step + latest_checkpoint = f"{scheme}://{bucket_name}/{prefix}" + return latest_checkpoint + + +def _http_file_size(scheme: str, host_name: str, path: str) -> int: + import requests + + response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True) + return int(response.headers.get("content-length")) + + +def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes: + import requests + + response = requests.get( + f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"} + ) + result = response.content + assert ( + len(result) == num_bytes + ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything + return result + + +def default_thread_count() -> int: + return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4)) + + +def pass_through_fn(fn, *args, **kwargs): + return fn(*args, **kwargs) + + +def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None): + q: Queue = Queue(maxsize=maxsize) + + sentinel = object() + + def fill_queue(): + try: + for value in g: + q.put(value) + except Exception as e: + q.put(e) + finally: + q.put(sentinel) + + thread_name = thread_name or repr(g) + thread = Thread(name=thread_name, target=fill_queue, daemon=True) + thread.start() + + for x in iter(q.get, sentinel): + if isinstance(x, Exception): + raise OLMoThreadError(f"generator thread {thread_name} failed") from x + else: + yield x + + +def roundrobin(*iterables): + """ + Call the given iterables in a round-robin fashion. For example: + ``roundrobin('ABC', 'D', 'EF') --> A D E B F C`` + """ + # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes + num_active = len(iterables) + nexts = cycle(iter(it).__next__ for it in iterables) + while num_active: + try: + for next in nexts: + yield next() + except StopIteration: + # Remove the iterator we just exhausted from the cycle. + num_active -= 1 + nexts = cycle(islice(nexts, num_active)) + + +def add_cached_path_clients(): + add_scheme_client(WekaClient) + + +class WekaClient(SchemeClient): + recoverable_errors = SchemeClient.recoverable_errors + ( + boto_exceptions.HTTPClientError, + boto_exceptions.ConnectionError, + ) + + scheme = "weka" + + def __init__(self, resource: str) -> None: + SchemeClient.__init__(self, resource) + self.bucket_name, self.path = WekaClient._split_cloud_path(resource, "weka") + self.s3 = _get_s3_client("weka") + self.object_info = None + + @staticmethod + def _split_cloud_path(url: str, provider: str) -> Tuple[str, str]: + """Split a full s3 path into the bucket name and path.""" + from urllib.parse import urlparse + + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad {} path {}".format(provider, url)) + bucket_name = parsed.netloc + provider_path = parsed.path + # Remove '/' at beginning of path. + if provider_path.startswith("/"): + provider_path = provider_path[1:] + return bucket_name, provider_path + + def _ensure_object_info(self): + if self.object_info is None: + try: + self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path) + except boto_exceptions.ClientError as e: + if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: + raise FileNotFoundError(f"weka://{self.bucket_name}/{self.path}") from e + raise e + + def get_etag(self) -> Optional[str]: + self._ensure_object_info() + assert self.object_info is not None + return self.object_info.get("ETag") + + def get_size(self) -> Optional[int]: + self._ensure_object_info() + assert self.object_info is not None + return self.object_info.get("ContentLength") + + def get_resource(self, temp_file: io.BufferedWriter) -> None: + self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path) + + def get_bytes_range(self, index: int, length: int) -> bytes: + response = self.s3.get_object( + Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}" + ) + return response["Body"].read() \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/run_zero.py b/3.test_cases/neuronx-distributed/olmo/run_zero.py new file mode 100644 index 00000000..25ef66f7 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/run_zero.py @@ -0,0 +1,77 @@ +import os +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +import torch.nn.functional as F +import torch_xla.core.xla_model as xm +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl +from torch.utils.data.distributed import DistributedSampler +import torch_xla.distributed.xla_backend +from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +from olmo.config import ModelConfig, TrainConfig, TokenizerConfig +from olmo.datasets import SortDataset +from olmo.model import OLMo +from olmo.tokenizer import Tokenizer + +# Set Neuron SDK environment variables +os.environ["XLA_USE_BF16"] = "1" + + +# Initialize XLA process group for torchrun +torch.distributed.init_process_group('xla') +device = "xla" +# XLA MP: get world size +world_size = xm.xrt_world_size() +olmo_config = TrainConfig.load("./configs/OLMo-7B.yaml") +olmo_config.model.init_device = "cpu" +tokenizer = Tokenizer.from_train_config(olmo_config) +model = OLMo(olmo_config.model) +model = model.to(device) +print("Loaded olmo model") +# Define the batch size and sequence length +batch_size = 1 +# create train and test dataset +train_dataset = SortDataset('train') +test_dataset = SortDataset('test') +train_sampler = DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=xm.get_ordinal(), + shuffle=True +) +train_loader = DataLoader( + train_dataset, + batch_size=1, + sampler=train_sampler +) +# We wrap the dataloader with MpDeviceLoader. This dataloader should take +# care of copying the tensors to device +train_loader = pl.MpDeviceLoader(train_loader, device) +#optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) +optimizer = ZeroRedundancyOptimizer( + model.parameters(), torch.optim.Adam, + lr=1e-4, pin_layout=False +) + +model.train() +pbar = tqdm(train_loader) +for idx, (x, y) in enumerate(pbar): + optimizer.zero_grad() + # forward the model + logits = model(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1 + ) + # backprop and update the parameters + loss.backward() + optimizer.step() + #xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step + pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}") + +# XLA: use xm.save instead of torch.save to ensure states are moved back to cpu +xm.save(model.state_dict(), "model.pt") diff --git a/3.test_cases/neuronx-distributed/olmo/run_zero.sbatch b/3.test_cases/neuronx-distributed/olmo/run_zero.sbatch new file mode 100644 index 00000000..4e163edb --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/run_zero.sbatch @@ -0,0 +1,6 @@ +#!/bin/bash +#SBATCH --nodes=4 +#SBATCH --exclusive +#SBATCH --output=logs/slurm-%x-%j.out + +srun neuron_parallel_compile ./run_zero.sh MIXED \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/run_zero.sh b/3.test_cases/neuronx-distributed/olmo/run_zero.sh new file mode 100755 index 00000000..b7b71216 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/run_zero.sh @@ -0,0 +1,67 @@ +#!/bin/bash +set -o pipefail + +sudo rmmod neuron; sudo modprobe neuron +sudo sysctl -w net.ipv4.ip_local_reserved_ports=44000,48620 +sudo sysctl -w kernel.threads-max=10000000 +ulimit -c unlimited + +NUM_NEURONCORES=32 +DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES" + +LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" +MALLOC_ARENA_MAX=64 +echo "MALLOC_ARENA_MAX" $MALLOC_ARENA_MAX +echo "LD_PRELOAD" $LD_PRELOAD + +if [ ! -z "$SLURM_NTASKS" ]; then + # if running inside slurm, handle here + MASTER_ADDR=(`scontrol show hostnames $SLURM_JOB_NODELIST`) + MASTER_PORT=2022 + WORLD_SIZE_JOB=$SLURM_NTASKS + RANK_NODE=$SLURM_NODEID + JOB_ID_TAG=job-"$SLURM_JOB_ID" + DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE_JOB --node_rank $RANK_NODE --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + echo $DISTRIBUTED_ARGS + export NEURON_RT_ROOT_COMM_ID=$MASTER_ADDR:46820 + export FI_EFA_FORK_SAFE=1 + export FI_EFA_USE_DEVICE_RDMA=1 + export FI_PROVIDER=efa + echo "WORLD_SIZE_JOB=$WORLD_SIZE_JOB, RANK_NODE=$RANK_NODE, MASTER_ADDR_JOB=$MASTER_ADDR_JOB, NODE_LIST=$NODE_LIST" + export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub + export HF_DATASETS_CACHE=$HOME/hf_cache/`hostname`/datasets +fi + +#Print Slurm Config +date;hostname; + +export TRAINING_PRECISION=$1 #options FP32, BF16, MIXED +export NEURON_RT_STOCHASTIC_ROUNDING_EN=1 + +if [[ "BF16" == $TRAINING_PRECISION ]]; then + echo "USING BF16 ONLY" + export XLA_USE_BF16=1 + export NEURON_CC_FLAGS="--retry_failed_compilation --distribution-strategy llm-training --model-type transformer" +elif [[ "MIXED" == $TRAINING_PRECISION ]]; then + echo "USING MIXED PRECISION BF16 and FP32" + export NEURON_CC_FLAGS="--retry_failed_compilation --enable-mixed-precision-accumulation --distribution-strategy llm-training --model-type transformer" +else + echo "USING FP32 as default" + export NEURON_CC_FLAGS="--retry_failed_compilation --distribution-strategy llm-training --model-type transformer" +fi + +NEURON_CC_FLAGS+=" --cache_dir=$HOME/neuron_cache/gpt_1p5B/`hostname`" + +export DISABLE_NUMERIC_CC_TOKEN=1 +export NEURON_RT_HIERARCHICAL_CC=1 + +export NEURON_RT_EXEC_TIMEOUT=600 +export TF_NUM_INTEROP_THREADS=8192 + +export NEURON_ENABLE_NOSEED_DROPOUT=1 + +GRAD_ACCUM_STEP=1 +BATCH_SIZE=1 + +torchrun $DISTRIBUTED_ARGS run_zero.py \ + |& tee $LOG_FILE_NAME \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/run_zero_compile.sbatch b/3.test_cases/neuronx-distributed/olmo/run_zero_compile.sbatch new file mode 100644 index 00000000..4e163edb --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/run_zero_compile.sbatch @@ -0,0 +1,6 @@ +#!/bin/bash +#SBATCH --nodes=4 +#SBATCH --exclusive +#SBATCH --output=logs/slurm-%x-%j.out + +srun neuron_parallel_compile ./run_zero.sh MIXED \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/tokenizers/.gitignore b/3.test_cases/neuronx-distributed/olmo/tokenizers/.gitignore new file mode 100644 index 00000000..94a2dd14 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/tokenizers/.gitignore @@ -0,0 +1 @@ +*.json \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/train_neuron.py b/3.test_cases/neuronx-distributed/olmo/train_neuron.py new file mode 100644 index 00000000..962b6a74 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/train_neuron.py @@ -0,0 +1,59 @@ +import os +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +import torch.nn.functional as F +import torch_xla.core.xla_model as xm +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl +from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +from olmo.config import ModelConfig, TrainConfig, TokenizerConfig +from olmo.datasets import SortDataset +from olmo.model import OLMo +from olmo.tokenizer import Tokenizer + +# Set Neuron SDK environment variables +os.environ["XLA_USE_BF16"] = "1" + + +device = "xla" +olmo_config = TrainConfig.load("./configs/OLMo-7B.yaml") +olmo_config.model.init_device = "cpu" +#tokenizer = Tokenizer.from_train_config(olmo_config) +model = OLMo(olmo_config.model) +model = model.to(device) +print("Loaded olmo model") +# Define the batch size and sequence length +batch_size = 4 +# create train and test dataset +train_dataset = SortDataset('train') +test_dataset = SortDataset('test') +train_loader = DataLoader( + train_dataset, + batch_size=batch_size, +) +# We wrap the dataloader with MpDeviceLoader. This dataloader should take +# care of copying the tensors to device +train_loader = pl.MpDeviceLoader(train_loader, device) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) +model.train() + +pbar = tqdm(train_loader) +for idx, (x, y) in enumerate(pbar): + optimizer.zero_grad() + # forward the model + logits = model(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1 + ) + # backprop and update the parameters + loss.backward() + xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step + pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}") + +# XLA: use xm.save instead of torch.save to ensure states are moved back to cpu +xm.save(model.state_dict(), "model.pt") \ No newline at end of file diff --git a/3.test_cases/neuronx-distributed/olmo/train_zero.py b/3.test_cases/neuronx-distributed/olmo/train_zero.py new file mode 100644 index 00000000..6aab6fc3 --- /dev/null +++ b/3.test_cases/neuronx-distributed/olmo/train_zero.py @@ -0,0 +1,64 @@ +import os +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +import torch.nn.functional as F +import torch_xla.core.xla_model as xm +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl +from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +from olmo.config import ModelConfig, TrainConfig, TokenizerConfig +from olmo.datasets import SortDataset +from olmo.model import OLMo +from olmo.tokenizer import Tokenizer + +# Set Neuron SDK environment variables +os.environ["XLA_USE_BF16"] = "1" + + +device = "xla" +olmo_config = TrainConfig.load("./configs/OLMo-1B.yaml") +olmo_config.model.init_device = "cpu" +#tokenizer = Tokenizer.from_train_config(olmo_config) +model = OLMo(olmo_config.model) +model = model.to(device) +print("Loaded olmo model") +# Define the batch size and sequence length +batch_size = 4 +# create train and test dataset +train_dataset = SortDataset('train') +test_dataset = SortDataset('test') +train_loader = DataLoader( + train_dataset, + batch_size=batch_size, +) +# We wrap the dataloader with MpDeviceLoader. This dataloader should take +# care of copying the tensors to device +train_loader = pl.MpDeviceLoader(train_loader, device) +#optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) +optimizer = ZeroRedundancyOptimizer( + model.parameters(), torch.optim.Adam, + lr=1e-4, pin_layout=False +) +model.train() + +pbar = tqdm(train_loader) +for idx, (x, y) in enumerate(pbar): + optimizer.zero_grad() + # forward the model + logits = model(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + y.view(-1), + ignore_index=-1 + ) + # backprop and update the parameters + loss.backward() + optimizer.step() + #xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step + pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}") + +# XLA: use xm.save instead of torch.save to ensure states are moved back to cpu +xm.save(model.state_dict(), "model.pt")