Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neuron distributed #359

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions 3.test_cases/neuronx-distributed/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# NeuronX distributed test cases <!-- omit in toc -->



MPT are GPT-style models in [llm-foundry](https://github.com/mosaicml/llm-foundry/tree/main) with some special features -- [Flash Attention](https://arxiv.org/abs/2205.14135) for efficiency, [ALiBi](https://arxiv.org/abs/2108.12409) for context length extrapolation, and stability improvements to mitigate loss spikes.

This project contains:

* AWS optimized [llm-foundry](https://github.com/mosaicml/llm-foundry/tree/main) container image.
* Slurm scripts for the [c4 dataset](https://huggingface.co/datasets/c4) preparation and multi-node distributed training.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a prerequisites section? You need a 2 node Trn1 PC? What would I need to change if I am doing this on SMHP?

## 1. Preparation

This guide assumes that you have the following:

* A functional Slurm cluster on AWS.
* Docker, [Pyxis](https://github.com/NVIDIA/pyxis) and [Enroot](https://github.com/NVIDIA/enroot) installed.
* An FSx for Lustre filesystem mounted on `/fsx`.

We recommend that you setup a Slurm cluster using the templates in the architectures [directory](../../1.architectures). Before creating the Slurm cluster, you need to setup the following environment variables:

```bash
export APPS_PATH=/apps
export ENROOT_IMAGE=$APPS_PATH/llm-foundry.sqsh
export FSX_PATH=/fsx
export DATA_PATH=$FSX_PATH/c4-dataset
export TEST_CASE_PATH=${HOME}/3.MPT # where you copy the test case or set to your test case path
cd $TEST_CASE_PATH
```

then follow the detailed instructions [here](../../1.architectures/2.aws-parallelcluster/README.md).

## 2. Build the container

Before running training jobs, you need to use an [Enroot](https://github.com/NVIDIA/enroot) container to retrieve and preprocess the input data. Below are the steps you need to follow:

1. Copy the test case files to your cluster. You will need `0.llm-foundry.Dockerfile`,
2. Build the Docker image with the command below in this directory.

```bash
docker build -t llm-foundry -f 0.llm-foundry.Dockerfile .
```

3. Once the Docker image is built, you can check if it is present with `docker images`. You should see an output similar to this one:

```bash
REPOSITORY TAG IMAGE ID CREATED SIZE
llm-foundry latest a964fb32cd53 2 weeks ago 23.6GB
...
```

4. Convert the Docker image to a squash file with the command below.

```bash
enroot import -o ${ENROOT_IMAGE} dockerd://llm-foundry:latest
```

The file will be stored in the `/apps` directory (default). The output should look as below.

```bash
[INFO] Fetching image

36a8c752c28a2db543d2a632a3fc1fcbd5789a6f3d45b9d3a24632420dedcfa8

[INFO] Extracting image content...
[INFO] Creating squashfs filesystem...

Parallel mksquashfs: Using 32 processors
Creating 4.0 filesystem on /apps/llm-foundry.sqsh, block size 131072.
[========================================================================================================================================================================================================================-] 291068/291068 100%

Exportable Squashfs 4.0 filesystem, gzip compressed, data block size 131072
uncompressed data, uncompressed metadata, uncompressed fragments, uncompressed xattrs
duplicates are not removed
...
```

It will take around 5 minutes to convert the container image from Docker to the Enroot format. Once done proceed to the next stage.

For ease of testing we've included a `Makefile` that automatically builds and imports the latest image. To run this, execute `make` or you can individually specify `make build` to build the Docker image, `make clean` to remove the squash file and `make import` to import the Dockerfile into enroot squash file.

## 3. Run the processing job

You need to retrieve input data and preprocess it before running the training job.

1. Run a preprocessing job by submitting the script `1.c4-preprocess.sbatch` to Slurm. The command will return the Slurm Job ID. You can use `squeue` to consult the status of your jobs.

```bash
sbatch 1.c4-preprocess.sbatch
```

It will create the streaming dataset for composer library using C4 dataset in `/fsx/c4-dataset` (default).

2. You see a new file in your current working directory called `c4-preprocess_XY.out` where `XY` corresponds the Slurm job ID. This is your output file and will capture the `STDOUT` and `STDERR` from your job. You can check how it progresses via the command `tail -f c4-preprocess_XY.out` with the correct job ID instead of `XY`. If running successfully, the job will generate an output similar to the except below.

```console
Downloading (…)okenizer_config.json: 100%|██████████| 156/156 [00:00<00:00, 1.09MB/s]
...
Downloading metadata: 100%|██████████| 2.40M/2.40M [00:01<00:00, 2.05MB/s]
...
train_small: 32%|███▏ | 31745/100000 [01:51<00:19, 3538.83it/s]
...
val_small: 100%|██████████| 10000/10000 [00:19<00:00, 514.19it/s]
```

Please be aware that this job downloads the tokenizer on demand (if it's not available under `./EleutherAI/gpt-neox-20b`), after which the tokenizer will be cached under `$HOME/.cache/huggingface`, and the `$HOME` directory is an NFS filesystem shared by the head node. Please consult the [HuggingFace cache management](https://huggingface.co/docs/datasets/cache) document to learn more about fine-grained control of the HuggingFace cache.

3. After the job completed, check `/fsx/c4-dataset` (default) which will contain a structure similar as below

```console
/fsx/c4-dataset/
├── train_small
│ ├── index.json
│ ├── shard.00000.mds
│ ├── shard.00001.mds
│ ├── shard.00002.mds
...
│ ├── shard.00023.mds
│ └── shard.00024.mds
└── val_small
├── index.json
├── shard.00000.mds
├── shard.00001.mds
└── shard.00002.mds
```

Once preprocessing is done, you will run a training job in the next stage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not need to compile the model? Can we add a few sentences on why? The Llama3 example on Trn1 has a section on compiling the models. Maybe add something similar?

## 4. Distributed training of MPT

Now that the data is preprocessed, we will pretrain a MPT model with [Mosaic Composer](https://github.com/mosaicml/composer).

1. Run a training job by submitting script `2.train-mpt-manual-distributed.sbatch` to Slurm via `sbatch` as shown below.

```bash
sbatch 2.train-mpt-manual-distributed.sbatch
```
by default it runs `mpt-7b` model. You can specify model to be trained as:
```bash
sbatch 2.train-mpt-manual-distributed.sbatch mpt-30b
```

2. When the training job completes successfully, it should produce a log output similar to the below in the `logs/` directory of `$TEST_CASE_PATH`.

```console
...
0: [batch=1/300000000]:
0: Train time/epoch: 0
0: Train time/batch: 0
0: Train time/sample: 0
0: Train time/batch_in_epoch: 0
0: Train time/sample_in_epoch: 0
0: Train time/token: 0
0: Train time/token_in_epoch: 0
0: Train memory/allocated_mem: 3.6287
0: Train memory/active_mem: 3.6287
0: Train memory/inactive_mem: 2.7844
0: Train memory/reserved_mem: 20.9650
0: Train memory/alloc_retries: 0
0: Train trainer/device_train_microbatch_size: 8
0: Train loss/train/total: 12.0000
0: Train metrics/train/LanguageCrossEntropy: 12.0000
0: Train metrics/train/LanguagePerplexity: 162754.5000
0: Train time/train: 0.0037
0: Train time/val: 0.0000
...
```

## 5. Authors / Reviewers

* [A] Keita Watanabe - mlkeita@
* [R] Pierre-Yves Aquilanti - pierreya@
* [R] Verdi March - marcverd@
42 changes: 42 additions & 0 deletions 3.test_cases/neuronx-distributed/mingpt/0.cpu.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we using this file?

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from torch.nn import functional as F
from mingpt.model import GPT
from mingpt.datasets import SortDataset
from mingpt.trainer import Trainer
from mingpt.configs import TrainConfig

# create train and test dataset
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
train_config = TrainConfig.get_default_config()
train_loader = DataLoader(
train_dataset,
batch_size=train_config.batch_size,
)

# create a GPT instance
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = GPT(model_config)
optimizer = model.configure_optimizers(train_config)

model.train()
pbar = tqdm(train_loader)
for idx, (x, y) in enumerate(pbar):
optimizer.zero_grad()
# forward the model
logits = model(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
y.view(-1),
ignore_index=-1
)
# backprop and update the parameters
loss.backward()
optimizer.step()
pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}")
53 changes: 53 additions & 0 deletions 3.test_cases/neuronx-distributed/mingpt/1.neuron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn import functional as F
import torch_xla.core.xla_model as xm
# XLA imports for parallel loader and multi-processing
import torch_xla.distributed.parallel_loader as pl

from mingpt.model import GPT
from mingpt.datasets import SortDataset
from mingpt.trainer import Trainer
from mingpt.configs import TrainConfig



device = 'xla'
# create train and test dataset
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
train_config = TrainConfig.get_default_config()
train_loader = DataLoader(
train_dataset,
batch_size=train_config.batch_size,
)
# We wrap the dataloader with MpDeviceLoader. This dataloader should take
# care of copying the tensors to device
train_loader = pl.MpDeviceLoader(train_loader, device)

# create a GPT instance
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = GPT(model_config)
model = model.to(device)
optimizer = model.configure_optimizers(train_config)


model.train()
pbar = tqdm(train_loader)
for idx, (x, y) in enumerate(pbar):
optimizer.zero_grad()
# forward the model
logits = model(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
y.view(-1),
ignore_index=-1
)
# backprop and update the parameters
loss.backward()
xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step
pbar.set_description(f"Iteration: {idx}, train loss: {loss.item():.5f}")
7 changes: 7 additions & 0 deletions 3.test_cases/neuronx-distributed/mingpt/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
5 changes: 5 additions & 0 deletions 3.test_cases/neuronx-distributed/mingpt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# minGPT test case <!-- omit in toc -->

This test case is an educational sample that guide you through how to construct distributed training codes using NeuronX distributed.


Loading