Dataset from file (#146)
Co-authored-by: Torsten Scholak
Co-authored-by: Oleksiy Ostapenko
3 people authored Mar 4, 2025
1 parent 23006dc commit 6df65c1
Showing 16 changed files with 648 additions and 101 deletions.
83 changes: 49 additions & 34 deletions docs/quick-start.md
@@ -224,7 +224,8 @@ Choose based on your goals for this tutorial.

For this tutorial, we'll use text from the [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset. This dataset is a free approximation of the WebText data OpenAI used for GPT-2, and it's perfect for our test run!

Create a configuration file for the dataset preparation. Copy the following content:
Create a configuration file for the dataset preparation.
Save the following as `./fast-llm-tutorial/prepare-config.yaml`:

=== "Small"

@@ -242,10 +243,15 @@ Create a configuration file for the dataset preparation.

tokenizer:
path: fast-llm-tutorial/pretrained-model

splits: # (3)!
training: 0.9
validation: 0.1
```

1. Processing speed scales linearly with the number of CPUs.
2. This restricts the dataset to the first 10K records of OpenWebText to speed up preprocessing. If you want to use the full dataset, replace this with `openwebtext`.
3. 90% train, 10% validation. These settings need to be adjusted based on the size of your dataset.

=== "Big"

@@ -263,11 +269,14 @@ Create a configuration file for the dataset preparation.

tokenizer:
path: fast-llm-tutorial/pretrained-model

splits: # (2)!
training: 0.99
validation: 0.01
```

1. Processing speed scales linearly with the number of CPUs.

Save it as `./fast-llm-tutorial/prepare-config.yaml`.
2. 99% train, 1% validation. These settings need to be adjusted based on the size of your dataset.

Fast-LLM ships with a `prepare` command that will download and preprocess the dataset for you.

@@ -498,33 +507,36 @@ Save the following as `fast-llm-tutorial/train-config.yaml`:
sequence_length: 1024
batch_size: 480 # (5)!
data:
format: file
path: fast-llm-tutorial/dataset/fast_llm_dataset.json # (6)!
split: [9, 1, 0] # (7)!
datasets:
Training:
type: file
path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)!
Validation:
type: file
path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)!
optimizer:
learning_rate:
base: 6.0e-04
pretrained:
format: llama # (8)!
format: llama # (7)!
path: fast-llm-tutorial/pretrained-model
model_weights: no # (9)!
model_weights: no # (8)!
model:
base_model:
transformer:
use_flash_attention: yes # (10)!
use_flash_attention: yes # (9)!
distributed:
training_dtype: bf16 # (11)!
training_dtype: bf16 # (10)!
run:
experiment_dir: fast-llm-tutorial/experiment
```

1. For the small run, we'll stop after 100 iterations.
2. The trained model will be saved in `Transformers` Llama format to `fast-llm-tutorial/experiment/export/llama/100` at the end of the small run. You can also save as a `Fast-LLM` checkpoint by setting the `format` to `fast_llm`.
3. Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace `null` with your own W&B entity name. If you don't want to use W&B, just ignore this section.
3. Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at a sequence length of 1024 and an 80GB GPU, a `micro_batch_size` of 60 should work well.
4. Must be divisible by the number of GPUs and the `micro_batch_size`. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch.
5. Location of the dataset metadata file generated in Step 4.
6. 90% train, 10% validation, 0% test. These settings need to be adjusted based on the size of your dataset.
4. Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at a sequence length of 1024 and an 80GB GPU, a `micro_batch_size` of 60 should work well.
5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch (see the quick check after this list).
6. Location of the dataset metadata files generated in Step 4.
7. Format of the pretrained model. Since SmolLM is a Llama model, we set this to `llama`.
8. We'll train SmolLM2-135M from scratch. You can set to `yes` to continue training from a checkpoint (if you put one in the model directory).
9. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`.
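
A quick check of the batch arithmetic in notes 4 and 5: the batch splits evenly into micro-batches, since $480 = 8 \times 60$, and the tokens-per-batch figure follows from

$$
480 \ \text{sequences} \times 1024 \ \tfrac{\text{tokens}}{\text{sequence}} = 491{,}520 \approx 5 \times 10^{5} \ \text{tokens per batch.}
$$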
@@ -556,32 +568,36 @@ Save the following as `fast-llm-tutorial/train-config.yaml`:
sequence_length: 4096
batch_size: 512 # (5)!
data:
format: file
path: fast-llm-tutorial/dataset/fast_llm_dataset.json # (6)!
split: [99, 1, 0] # (7)!
optimizer: # (8)!
datasets:
Training:
type: file
path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)!
Validation:
type: file
path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)!
optimizer: # (7)!
weight_decay: 0.1
beta_1: 0.9
beta_2: 0.95
learning_rate: # (9)!
learning_rate: # (8)!
base: 6.0e-04
minimum: 6.0e-05
decay_style: cosine
decay_iterations: 100_000
warmup_iterations: 2000
pretrained:
format: llama # (10)!
format: llama # (9)!
path: fast-llm-tutorial/pretrained-model
model_weights: yes # (11)!
model_weights: yes # (10)!
model:
base_model:
transformer:
use_flash_attention: yes # (12)!
cross_entropy_impl: fused # (13)!
use_flash_attention: yes # (11)!
cross_entropy_impl: fused # (12)!
multi_stage:
zero_stage: 2 # (14)!
zero_stage: 2 # (13)!
distributed:
training_dtype: bf16 # (15)!
training_dtype: bf16 # (14)!
run:
experiment_dir: fast-llm-tutorial/experiment
```
@@ -592,15 +608,14 @@ Save the following as `fast-llm-tutorial/train-config.yaml`:
4. Adjust the number of sequences per GPU based on GPU memory. Considering a 4k token sequence length and 80GB GPUs, a `micro_batch_size` of 1 should work well.
5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 4k tokens per sequence, 512 corresponds to about 2.1 million tokens per batch.
6. Location of the dataset metadata files generated in Step 4.
7. 99% train, 1% validation, 0% test. These settings need to be adjusted based on the size of your dataset. If you're using a smaller dataset, you need to increase the validation split.
8. These are good default optimizer settings for training models.
9. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla.
10. Format of the pretrained model. Since it's a Llama model, we set this to `llama`.
11. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`.
12. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`.
13. Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code.
14. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively.
15. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`.
7. These are good default optimizer settings for training models.
8. We are using a cosine decay schedule with linear warmup (see the sketch after this list). After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate decays to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate, per Chinchilla.
9. Format of the pretrained model. Since it's a Llama model, we set this to `llama`.
10. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`.
11. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`.
12. Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code.
13. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively.
14. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`.
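
To make note 8 concrete, here is a sketch of the warmup-plus-cosine schedule these options describe (the standard formulation; the exact boundary conventions in Fast-LLM's implementation may differ slightly):

$$
\eta(t) =
\begin{cases}
\eta_{\text{base}} \cdot \dfrac{t}{T_{\text{warmup}}}, & t \le T_{\text{warmup}}, \\[2ex]
\eta_{\min} + \dfrac{1}{2}\left(\eta_{\text{base}} - \eta_{\min}\right)\left(1 + \cos\left(\pi \cdot \dfrac{t - T_{\text{warmup}}}{T_{\text{decay}} - T_{\text{warmup}}}\right)\right), & T_{\text{warmup}} < t \le T_{\text{decay}},
\end{cases}
$$

with $\eta_{\text{base}} = 6 \times 10^{-4}$, $\eta_{\min} = 6 \times 10^{-5}$, $T_{\text{warmup}} = 2000$, and $T_{\text{decay}} = 100{,}000$ as configured above.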

## 🔑 (Optional) Step 6: Add Your Weights & Biases API Key

185 changes: 185 additions & 0 deletions docs/recipes/data-configuration.md
@@ -0,0 +1,185 @@
---
title: Configuring Data for Training
---

In this section, we show how to configure datasets through a series of examples.

We already saw an example dataset configuration in the [quick-start guide](../quick-start.md), where we prepared a simple dataset and split it into training and validation sub-datasets, and used these to train a small model. This was done by:

1. Defining a dataset preparation configuration.
2. Running `fast-llm prepare` with said configuration. This generated some binary files along with two fast-llm configuration files, `fast-llm-tutorial/dataset/fast_llm_config_training.yaml` and `fast-llm-tutorial/dataset/fast_llm_config_validation.yaml`.
3. Defining a fast-llm data configuration that uses those datasets:

```yaml
data:
  datasets:
    Training:
      type: file
      path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml
    Validation:
      type: file
      path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml
```
4. Running `fast-llm train` with said configuration.

In this section we are interested in generalizing step 3. For more details on steps 1 and 2, please refer to the quick-start guide or [this example](data-configuration.md).

## Example 1: Blending multiple datasets

In this example, we have three datasets and want to sample from each of them during training with probabilities 0.70, 0.25 and 0.05. For this, we use the `blended` type which takes other datasets as arguments:

```yaml
data:
  datasets:
    Training:
      type: blended
      datasets:
        - type: file
          path: path/to/dataset_0.yaml
        - type: file
          path: path/to/dataset_1.yaml
        - type: file
          path: path/to/dataset_2.yaml
      weights: [0.70, 0.25, 0.05]
```

!!! note "Dataset wrappers"
The `blended` dataset wrapper is one example of the many dataset wrappers available in fast-llm. Such wrappers may be nested (almost) arbitrarily to generate the dataset scheme that fits your needs. Fast-LLM will use the `type` argument to dynamically select the appropriate configuration class(es). With some effort you can even create your own wrapper!
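
For instance, wrappers may wrap other wrappers. The sketch below is a hypothetical configuration (the paths are placeholders, not taken from the guide) that blends a pre-sampled `file` dataset with a plain one:

```yaml
data:
  datasets:
    Training:
      type: blended              # outer wrapper: blend two datasets
      datasets:
        - type: sampled          # inner wrapper: override sampling for this dataset only
          dataset:
            type: file
            path: path/to/dataset_a.yaml
          sampling:
            shuffle: epoch
        - type: file
          path: path/to/dataset_b.yaml
      weights: [0.5, 0.5]
```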

## Example 2: Configure shuffling

In this example, we have a large dataset that comes pre-shuffled, so shuffling is unnecessary for the first epoch.

```yaml
data:
  datasets:
    Training:
      type: file
      path: path/to/dataset.yaml
  sampling:
    shuffle: skip_first_epoch
```

## Example 3: Disable shuffling for validation

In this example, we want to disable shuffling entirely, but only for the validation dataset. We can do this with the `sampled` dataset wrapper:

```yaml
data:
  datasets:
    Training:
      type: file
      path: path/to/training_dataset.yaml
    Validation:
      type: sampled
      dataset:
        type: file
        path: path/to/validation_dataset.yaml
      sampling:
        shuffle: disabled
```

!!! note "More about sampling configuration"
Sampling parameters may be defined globally through the data configuration (example 2), through dataset wrapper(s) (examples 3, 4), or both (example 5). When a dataset's sampling is configured through both methods (or through multiple nested wrappers), the innermost wrapper overrides the data configuration (or the next-to-innermost wrapper) for the explicitly defined fields (and only those).
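
As a small illustration of this precedence rule, the hypothetical configuration below combines examples 2 and 3: the `sampled` wrapper overrides only `shuffle` for the validation dataset, while every other sampling field (and the training dataset) still follows the global `data.sampling` section:

```yaml
data:
  datasets:
    Training:
      type: file
      path: path/to/training_dataset.yaml     # uses the global sampling settings below
    Validation:
      type: sampled
      dataset:
        type: file
        path: path/to/validation_dataset.yaml
      sampling:
        shuffle: disabled                      # overrides only this field, only for Validation
  sampling:
    shuffle: skip_first_epoch                  # global default for all datasets
```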

## Example 4: Set sampling seed for individual datasets

In this example, we have a blend of datasets as in example 1, but we wish to set the seed for each dataset individually for reproducibility reasons. For this, we use the `seed` field in each `sampled` wrapper's `sampling` section:

```yaml
data:
  datasets:
    Training:
      type: blended
      datasets:
        - type: sampled
          dataset:
            type: file
            path: path/to/dataset_0.yaml
          sampling:
            seed: 1234
        - type: sampled
          dataset:
            type: file
            path: path/to/dataset_1.yaml
          sampling:
            seed: 2345
        - type: sampled
          dataset:
            type: file
            path: path/to/dataset_2.yaml
          sampling:
            seed: 3456
      weights: [0.70, 0.25, 0.05]
```

!!! note "Default seed"
In the absence of an explicit seed, Fast-LLM falls back to the default seed (the default of `data.sampling`), and applies seed shifts to ensure different seeds for each phase and for the various blended datasets.
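
For example, that default can be set once at the data level rather than per dataset. This is a sketch; it assumes the global field is also called `seed`, mirroring the wrapper field used above:

```yaml
data:
  datasets:
    Training:
      type: file
      path: path/to/dataset.yaml
  sampling:
    seed: 777   # default seed; per-phase and per-blend seeds are derived from it via shifts
```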

## Example 5: Advanced scenario

In this example, we combine everything we learned so far to create a complex scenario, where:

* The training dataset is a blend of two datasets, one of which is itself a blend of three datasets.
* All datasets except one come pre-shuffled, so we can skip shuffling for the first epoch.
* We want to set the seed explicitly for the validation and innermost blended datasets, but keep the default seed for the others.

```yaml
data:
  datasets:
    Training:
      type: blended
      datasets:
        - type: sampled
          dataset:
            type: blended
            datasets:
              - type: file
                # Seed = 1234
                path: path/to/dataset_0.yaml
              - type: file
                # Seed = 1234 + blend_shift, shuffle = skip_first_epoch
                path: path/to/dataset_1.yaml
              - type: sampled
                dataset:
                  type: file
                  # Seed = 1234 + 2 * blend_shift, shuffle = epoch
                  path: path/to/dataset_2.yaml
                sampling:
                  # Shuffle each epoch independently (default shuffling)
                  shuffle: epoch
          sampling:
            seed: 1234
        - type: file
          # Seed = default + train_shift + 2 * blend_shift, shuffle = skip_first_epoch
          path: path/to/dataset_3.yaml
      weights: [0.70, 0.25, 0.05]
    Validation:
      type: sampled
      dataset:
        type: file
        # Seed = 2345, shuffle = skip_first_epoch
        path: path/to/validation_dataset.yaml
      sampling:
        seed: 2345
  sampling:
    shuffle: skip_first_epoch
```

!!! note "Configure from file"
If a dataset configuration is especially complex, makes the data configuration excessively big, or is reused across many experiments, you may want to save it to a yaml file and refer to it in the config using a `file` dataset. This can be used to reduce the present example to
```yaml
data:
  datasets:
    Training:
      type: file
      path: path/to/training_dataset_config.yaml
    Validation:
      type: file
      path: path/to/validation_dataset_config.yaml
  sampling:
    shuffle: skip_first_epoch
```
In fact, all the `file` datasets we've used so far are of this format: they consist of more elementary `memmap` datasets, optionally wrapped with `blended` and/or `slice` wrappers.
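
As an illustration, the training split produced by `prepare` might look roughly like the sketch below; the `slice`/`memmap` field names and the shard path are assumptions based on the description above, not copied from a generated file:

```yaml
# Hypothetical fast_llm_config_training.yaml: a 90% slice of a single memmap shard.
type: slice
dataset:
  type: memmap
  path: shard_0_0
begin: 0
end: 0.9
```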
4 changes: 1 addition & 3 deletions fast_llm/data/dataset/config.py
@@ -216,9 +216,6 @@ def build_and_sample(
from fast_llm.data.dataset.blended import BlendedDataset

# Build and sample the datasets.
# TODO: Vary the seed?
# Add 5 times the standard deviation (of a binomial distribution)
# so the probability of sampling more than this amount during blending is negligible.

sampled_datasets = [
dataset.build_and_sample(
@@ -230,6 +227,7 @@
if self.legacy
else math.ceil(weight * sampling.num_samples) + 1
),
# TODO: Seed may not be unique for nested blended datasets.
config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}),
),
)
