From 6df65c16f627fef6b3c570a59c6b3699f652b4df Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Mar 2025 16:36:04 -0500 Subject: [PATCH] Dataset from file (#146) Co-authored-by: Torsten Scholak Co-authored-by: Oleksiy Ostapenko --- docs/quick-start.md | 83 ++++---- docs/recipes/data-configuration.md | 185 ++++++++++++++++++ fast_llm/data/dataset/config.py | 4 +- fast_llm/data/dataset/gpt/config.py | 61 +++++- fast_llm/data/dataset/gpt/memmap.py | 29 ++- fast_llm/data/dataset/gpt/sampled.py | 12 +- fast_llm/data/preparator/gpt_memmap/config.py | 6 + .../data/preparator/gpt_memmap/prepare.py | 176 +++++++++++++---- fast_llm/utils.py | 5 +- mkdocs.yaml | 3 +- tests/common.py | 12 +- tests/data/common.py | 56 +++++- tests/data/test_blending.py | 4 +- tests/data/test_dataset_from_file.py | 12 ++ tests/data/test_memmap.py | 6 +- tests/data/test_prepare_gpt_memmap.py | 95 +++++++++ 16 files changed, 648 insertions(+), 101 deletions(-) create mode 100644 docs/recipes/data-configuration.md create mode 100644 tests/data/test_dataset_from_file.py diff --git a/docs/quick-start.md b/docs/quick-start.md index abecc96c..ee228fb9 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -224,7 +224,8 @@ Choose based on your goals for this tutorial. For this tutorial, we'll use text from the [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset. This dataset is a free approximation of the WebText data OpenAI used for GPT-2, and it's perfect for our test run! -Create a configuration file for the dataset preparation. Copy the following content: +Create a configuration file for the dataset preparation. +Save the following as `./fast-llm-tutorial/prepare-config.yaml``: === "Small" @@ -242,10 +243,15 @@ Create a configuration file for the dataset preparation. Copy the following cont tokenizer: path: fast-llm-tutorial/pretrained-model + + splits: # (3)! + training: 0.9 + validation: 0.1 ``` 1. Processing speed scales linearly with the number of CPUs. 2. This small dataset restricts to the first 10K records of the OpenWebText dataset to speed up the process. If you want to use the full dataset, replace with `openwebtext`. + 3. 90% train, 10% validation. These settings need to be adjusted based on the size of your dataset. === "Big" @@ -263,11 +269,14 @@ Create a configuration file for the dataset preparation. Copy the following cont tokenizer: path: fast-llm-tutorial/pretrained-model + + splits: # (2)! + training: 0.99 + validation: 0.01 ``` 1. Processing speed scales linearly with the number of CPUs. - -Save it as `./fast-llm-tutorial/prepare-config.yaml`. + 2. 99% train, 1% validation. These settings need to be adjusted based on the size of your dataset. Fast-LLM ships with a `prepare` command that will download and preprocess the dataset for you. @@ -498,22 +507,26 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: sequence_length: 1024 batch_size: 480 # (5)! data: - format: file - path: fast-llm-tutorial/dataset/fast_llm_dataset.json # (6)! - split: [9, 1, 0] # (7)! + datasets: + Training: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + Validation: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! optimizer: learning_rate: base: 6.0e-04 pretrained: - format: llama # (8)! + format: llama # (7)! path: fast-llm-tutorial/pretrained-model - model_weights: no # (9)! + model_weights: no # (8)! model: base_model: transformer: - use_flash_attention: yes # (10)! + use_flash_attention: yes # (9)! 
distributed: - training_dtype: bf16 # (11)! + training_dtype: bf16 # (10)! run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -521,10 +534,9 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: 1. For the small run, we'll stop after 100 iterations. 2. The trained model will be saved in `Transformers` Llama format to `fast-llm-tutorial/experiment/export/llama/100` at the end of the small run. You can also save as a `Fast-LLM` checkpoint by setting the `format` to `fast_llm`. 3. Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace `null` with your own W&B entity name. If you don't want to use W&B, just ignore this section. - 3. Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at 1024 sequenced length and a 80GB GPU, a `micro_batch_size` of 60 should work well. - 4. Must be divisible by the number of GPUs and the `micro_batch_size`. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch. - 5. Location of the dataset metadata file generated in Step 4. - 6. 90% train, 10% validation, 0% test. These settings need to be adjusted based on the size of your dataset. + 4. Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at 1024 sequenced length and a 80GB GPU, a `micro_batch_size` of 60 should work well. + 5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch. + 6. Location of the dataset metadata files generated in Step 4. 7. Format of the pretrained model. Since SmolLM is a Llama model, we set this to `llama`. 8. We'll train SmolLM2-135M from scratch. You can set to `yes` to continue training from a checkpoint (if you put one in the model directory). 9. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. @@ -556,32 +568,36 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: sequence_length: 4096 batch_size: 512 # (5)! data: - format: file - path: fast-llm-tutorial/dataset/fast_llm_dataset.json # (6)! - split: [99, 1, 0] # (7)! - optimizer: # (8)! + datasets: + Training: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + Validation: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! + optimizer: # (7)! weight_decay: 0.1 beta_1: 0.9 beta_2: 0.95 - learning_rate: # (9)! + learning_rate: # (8)! base: 6.0e-04 minimum: 6.0e-05 decay_style: cosine decay_iterations: 100_000 warmup_iterations: 2000 pretrained: - format: llama # (10)! + format: llama # (9)! path: fast-llm-tutorial/pretrained-model - model_weights: yes # (11)! + model_weights: yes # (10)! model: base_model: transformer: - use_flash_attention: yes # (12)! - cross_entropy_impl: fused # (13)! + use_flash_attention: yes # (11)! + cross_entropy_impl: fused # (12)! multi_stage: - zero_stage: 2 # (14)! + zero_stage: 2 # (13)! distributed: - training_dtype: bf16 # (15)! + training_dtype: bf16 # (14)! run: experiment_dir: fast-llm-tutorial/experiment ``` @@ -592,15 +608,14 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: 4. Adjust the number of sequences per GPU based on GPU memory. Considering a 4k token sequence length and 80GB GPUs, a `micro_batch_size` of 1 should work well. 5. Must be divisible by the number of GPUs and the `micro_batch_size`. At 4k tokens per sequence, 512 corresponds to about 2.1 million tokens per batch. 6. 
Location of the dataset metadata file generated in Step 4. - 7. 99% train, 1% validation, 0% test. These settings need to be adjusted based on the size of your dataset. If you're using a smaller dataset, you need to increase the validation split. - 8. These are good default optimizer settings for training models. - 9. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla. - 10. Format of the pretrained model. Since it's a Llama model, we set this to `llama`. - 11. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`. - 12. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. - 13. Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code. - 14. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively. - 15. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. + 7. These are good default optimizer settings for training models. + 8. We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate `base` at `warmup_iterations`, the learning rate will decay to `minimum` at `decay_iterations`, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla. + 9. Format of the pretrained model. Since it's a Llama model, we set this to `llama`. + 10. We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to `no`. + 11. By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to `no`. + 12. Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code. + 13. We are using ZeRO stage 2 for this tutorial. You can set this to `1`, `2`, or `3` for ZeRO-1, ZeRO-2, or ZeRO-3, respectively. + 14. `bf16` (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use `fp16` (half-precision floating point) for training instead of `bf16`. ## 🔑 (Optional) Step 6: Add Your Weights & Biases API Key diff --git a/docs/recipes/data-configuration.md b/docs/recipes/data-configuration.md new file mode 100644 index 00000000..ba3fe91e --- /dev/null +++ b/docs/recipes/data-configuration.md @@ -0,0 +1,185 @@ +--- +title: Configuring Data for Training +--- + +In this section we show how to configure datasets through a series of examples + +We already saw an example dataset configuration in the [quick-start guide](../quick-start.md), where we prepared a simple dataset and split it into training and validation sub-datasets, and used these to train a small model. This was done by: + +1. Defining a dataset preparation configuration. +2. Running `fast-llm prepare` with said configuration. 
This generated some binary files along with two fast-llm configuration files, `fast-llm-tutorial/dataset/fast_llm_config_training.yaml` and `fast-llm-tutorial/dataset/fast_llm_config_validation.yaml`.
+3. Defining a fast-llm data configuration that uses those datasets:
+
+    ```yaml
+    data:
+      datasets:
+        Training:
+          type: file
+          path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml
+        Validation:
+          type: file
+          path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml
+    ```
+
+4. Running `fast-llm train` with said configuration.
+
+In this section we are interested in generalizing step 3. For more details on steps 1 and 2, please refer to the quick-start guide or [this example](data-preparation.md).
+
+## Example 1: Blending multiple datasets
+
+In this example, we have three datasets and want to sample from each of them during training with probabilities 0.70, 0.25 and 0.05. For this, we use the `blended` type, which takes other datasets as arguments:
+
+```yaml
+data:
+  datasets:
+    Training:
+      type: blended
+      datasets:
+        - type: file
+          path: path/to/dataset_0.yaml
+        - type: file
+          path: path/to/dataset_1.yaml
+        - type: file
+          path: path/to/dataset_2.yaml
+      weights: [0.70, 0.25, 0.05]
+```
+
+!!! note "Dataset wrappers"
+    The `blended` dataset wrapper is one example of the many dataset wrappers available in Fast-LLM. Such wrappers may be nested (almost) arbitrarily to generate the dataset scheme that fits your needs. Fast-LLM will use the `type` argument to dynamically select the appropriate configuration class(es). With some effort you can even create your own wrapper!
+
+## Example 2: Configure shuffling
+
+In this example, we have a large dataset that comes pre-shuffled, so shuffling is unnecessary for the first epoch.
+
+```yaml
+data:
+  datasets:
+    Training:
+      type: file
+      path: path/to/dataset.yaml
+  sampling:
+    shuffle: skip_first_epoch
+```
+
+## Example 3: Disable shuffling for validation
+
+In this example, we want to disable shuffling entirely, but only for the validation dataset. We can do this with the `sampled` dataset wrapper:
+
+```yaml
+data:
+  datasets:
+    Training:
+      type: file
+      path: path/to/training_dataset.yaml
+    Validation:
+      type: sampled
+      dataset:
+        type: file
+        path: path/to/validation_dataset.yaml
+
+      sampling:
+        shuffle: disabled
+```
+
+!!! note "More about sampling configuration"
+    Sampling parameters may be defined globally through the data configuration (example 2), through dataset wrapper(s) (examples 3, 4), or both (example 5). When a dataset's sampling is configured through both methods (or through multiple nested wrappers), the (innermost) wrapper overrides the data configuration (or the next-to-innermost wrapper) for the explicitly defined fields (and only those).
+
+## Example 4: Set sampling seed for individual datasets
+
+In this example, we have a blend of datasets as in example 1, but we wish to set the seed for each dataset individually for reproducibility reasons. For this, we wrap each dataset with the `sampled` wrapper and set the `seed` field of its `sampling` section:
+
+```yaml
+data:
+  datasets:
+    Training:
+      type: blended
+      datasets:
+        - type: sampled
+          dataset:
+            type: file
+            path: path/to/dataset_0.yaml
+          sampling:
+            seed: 1234
+        - type: sampled
+          dataset:
+            type: file
+            path: path/to/dataset_1.yaml
+          sampling:
+            seed: 2345
+        - type: sampled
+          dataset:
+            type: file
+            path: path/to/dataset_2.yaml
+          sampling:
+            seed: 3456
+      weights: [0.70, 0.25, 0.05]
+```
+
+!!! 
note "Default seed"
+    In the absence of an explicit seed, Fast-LLM uses a default seed (`data.sampling`'s default) instead, and uses seed shifts to ensure different seeds for each phase and for the various blended datasets.
+
+## Example 5: Advanced scenario
+
+In this example, we combine everything we learned so far to create a complex scenario, where:
+
+* The training dataset is a blend of two datasets, one of them being itself a blend of three datasets.
+* All datasets except for one come pre-shuffled, so we can skip shuffling for the first epoch.
+* We want to set the seed explicitly for the validation and innermost blended datasets, but keep the default seed for the others.
+
+```yaml
+data:
+  datasets:
+    Training:
+      type: blended
+      datasets:
+        - type: sampled
+          dataset:
+            type: blended
+            datasets:
+              - type: file
+                # Seed = 1234
+                path: path/to/dataset_0.yaml
+              - type: file
+                # Seed = 1234 + blend_shift, shuffle = skip_first_epoch
+                path: path/to/dataset_1.yaml
+              - type: sampled
+                dataset:
+                  type: file
+                  # Seed = 1234 + 2 * blend_shift, shuffle = epoch
+                  path: path/to/dataset_2.yaml
+                sampling:
+                  # Shuffle each epoch independently (default shuffling)
+                  shuffle: epoch
+          sampling:
+            seed: 1234
+        - type: file
+          # Seed = default + train_shift + 2 * blend_shift, shuffle = skip_first_epoch
+          path: path/to/dataset_3.yaml
+      weights: [0.70, 0.25, 0.05]
+    Validation:
+      type: sampled
+      dataset:
+        type: file
+        # Seed = 2345, shuffle = skip_first_epoch
+        path: path/to/validation_dataset.yaml
+      sampling:
+        seed: 2345
+  sampling:
+    shuffle: skip_first_epoch
+```
+
+!!! note "Configure from file"
+    If a dataset configuration is especially complex and makes the data configuration excessively big, or if it is reused across many experiments, you may want to save it to a yaml file and refer to it in the config using a `file` dataset. This can be used to reduce the present example to:
+    ```yaml
+    data:
+      datasets:
+        Training:
+          type: file
+          path: path/to/training_dataset_config.yaml
+        Validation:
+          type: file
+          path: path/to/validation_dataset_config.yaml
+      sampling:
+        shuffle: skip_first_epoch
+    ```
+    In fact, all the elementary `file` datasets we've used so far are of this format, and consist of more elementary `memmap` datasets optionally wrapped with `blended` and/or `slice` wrappers.
diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py
index 58d00c95..431a28a0 100644
--- a/fast_llm/data/dataset/config.py
+++ b/fast_llm/data/dataset/config.py
@@ -216,9 +216,6 @@ def build_and_sample(
         from fast_llm.data.dataset.blended import BlendedDataset
 
         # Build and sample the datasets.
-        # TODO: Vary the seed?
-        # Add 5 times the standard deviation (of a binomial distribution)
-        # so the probability of sampling more than this amount during blending is negligible.
         sampled_datasets = [
             dataset.build_and_sample(
@@ -230,6 +227,7 @@ def build_and_sample(
                     if self.legacy
                     else math.ceil(weight * sampling.num_samples) + 1
                 ),
+                # TODO: Seed may not be unique for nested blended datasets.
config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}), ), ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 80788922..d6cebd75 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -6,8 +6,10 @@ import typing import warnings +import yaml + from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import ( BlendedDatasetConfig, ConcatenatedDatasetConfig, @@ -164,11 +166,21 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", hint=FieldHint.core, ) + num_documents: int | None = Field( + default=None, + desc="Expected number of documents in the dataset.", + hint=FieldHint.optional, + ) + num_tokens: int | None = Field( + default=None, + desc="Expected number of tokens in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path) + return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) @config_class() @@ -209,8 +221,48 @@ class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): datasets: list[GPTSampledDatasetConfig] = FieldUpdate() +@config_class() +class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "file" + path: pathlib.Path = Field( + default=None, + desc="The path to a dataset config file.", + hint=FieldHint.core, + ) + + def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + config = self._load_config() + return config.build_and_sample(sampling) + + def build(self) -> SamplableDataset: + config = self._load_config() + assert isinstance(config, GPTSamplableDatasetConfig) + return config.build() + + def _load_config(self): + assert self.path.is_file() + return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + + def _convert_paths(self, config): + # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. + # Assuming all path are in a field named "path" + # TODO: Find a more generic way + if isinstance(config, dict): + for key, value in config.items(): + self._convert_paths(value) + if "path" in config: + assert isinstance(config["path"], (str, pathlib.Path)) + config["path"] = self.path.parent / config["path"] + elif isinstance(config, list): + for value in config: + self._convert_paths(value) + return config + + @config_class() class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): + # TODO v0.3: Remove. _abstract: typing.ClassVar[bool] = False type_: typing.ClassVar[str | None] = "concatenated_memmap" path: pathlib.Path = Field( @@ -219,8 +271,11 @@ class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): hint=FieldHint.core, ) + def _validate(self) -> None: + warnings.warn("`concatenated_memmap` dataset is deprecated. 
Use `file` instead.", DeprecationWarning) + super()._validate() + def build(self) -> "GPTConcatenatedDataset": - pass assert self.path.is_dir() index_path = self.path / "index.txt" diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 3f6d1784..c95b3705 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -20,10 +20,16 @@ class GPTMemmapDataset(GPTIndexedDataset): See https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#data-preprocessing for more details. """ - def __init__(self, name: str, prefix: pathlib.Path | str): - self._init(name, prefix) - - def _init(self, name: str, prefix: pathlib.Path | str) -> None: + def __init__( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None = None, + num_tokens: int | None = None, + ): + self._init(name, prefix, num_documents, num_tokens) + + def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -41,6 +47,9 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: _ = struct.unpack(" None: self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - def __getstate__(self) -> tuple[str, pathlib.Path]: - return (self._name, self._prefix) + self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) + if num_tokens is not None: + assert self._num_tokens == num_tokens + + def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: + return (self._name, self._prefix, self._num_documents, self._num_tokens) - def __setstate__(self, state: tuple[str, pathlib.Path]): + def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) def __del__(self): @@ -120,7 +133,7 @@ def __len__(self) -> int: @property def num_tokens(self) -> int: - return div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) + return self._num_tokens def get_document_sizes(self) -> np.ndarray: """ diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index e88a4efe..9fa830fd 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -154,10 +154,20 @@ def _sample(self) -> None: "config": self._config.to_serialized(), } self._load_yaml_data(yaml_data) + if self._yaml_path is not None: if self._yaml_path.is_file(): - Assert.eq(yaml.safe_load(self._yaml_path.open("r")), yaml_data) + loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) + if loaded_yaml_data != yaml_data: + raise RuntimeError( + f"Invalid dataset cache for dataset {self.name}." + " If this is due to an intended configuration change," + " please delete the cache before continuing." + f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}" + f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}" + ) # Dataset is already sampled, skip. 
+ logger.info(f"Using existing sampling for dataset {self.name}") return else: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 63f20bf3..2c4311c3 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -158,6 +158,12 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + splits: dict[str, float] | None = Field( + default=None, + desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." + " Does not shuffle samples.", + hint=FieldHint.optional, + ) def _validate(self) -> None: assert self.tokenizer.path is not None diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index e029137c..77995970 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -12,13 +12,21 @@ import torch.distributed import tqdm import transformers - +import yaml + +from fast_llm.data.dataset.gpt.config import ( + GPTBlendedDatasetConfig, + GPTDatasetSliceConfig, + GPTIndexedDatasetConfig, + GPTMemmapDatasetConfig, +) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.tokenizer import Tokenizer -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum logger = logging.getLogger(__name__) @@ -65,7 +73,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict "num_tokens": num_tokens, } - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> dict[str, typing.Any]: + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" shard_output_path = self._config.output_path / prefix @@ -83,12 +91,14 @@ def _document_generator(): GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) - dataset_dict = { - "prefix": prefix, - "num_documents": len(shard_dataset), # Use the length of the shard dataset directly - "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), - } - return dataset_dict + return GPTMemmapDatasetConfig.from_dict( + { + "type": "memmap", + "path": prefix, + "num_documents": len(shard_dataset), # Use the length of the shard dataset directly + "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + } + ) def _load_dataset(self) -> datasets.Dataset: dataset = datasets.load_dataset( @@ -158,20 +168,12 @@ def run(self) -> None: # Load tokenizer self._tokenizer = Tokenizer(config=self._config.tokenizer) - # Set data type if not provided - if self._config.dataset.data_type is None: - # Decide the datatype based on the tokenizer vocabulary size - vocab_size = self._tokenizer.vocab_size - if vocab_size <= np.iinfo(np.int16).max: - self._data_type = DataType.int16 - # elif vocab_size <= np.iinfo(np.uint16).max: - # self._data_type = DataType.uint16 # Not supported by Fast-LLM's DataType - elif vocab_size 
<= np.iinfo(np.int32).max: - self._data_type = DataType.int32 - else: - raise ValueError(f"Tokenizer vocabulary size {vocab_size} is too large. This is likely an error.") - else: - self._data_type = self._config.dataset.data_type + # Decide the datatype based on the tokenizer vocabulary size + self._data_type = ( + get_unsigned_integer_type(self._tokenizer.vocab_size) + if self._config.dataset.data_type is None + else self._config.dataset.data_type + ) # Initialize distributed processing if self._config.distributed.world_size > 1: @@ -238,32 +240,130 @@ def run(self) -> None: # Use multiprocessing to save each shard in parallel on all ranks with multiprocessing.Pool(processes=self._config.saving_workers) as pool: - dataset_dicts = pool.map(self._save_shard, shards) + dataset_configs = pool.map(self._save_shard, shards) # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: - all_dataset_dicts = [None] * self._config.distributed.world_size - torch.distributed.gather_object(dataset_dicts, all_dataset_dicts, dst=0) - dataset_dicts = [item for sublist in all_dataset_dicts for item in sublist] + all_dataset_configs = [None] * self._config.distributed.world_size + torch.distributed.gather_object(dataset_configs, all_dataset_configs, dst=0) + dataset_configs = [item for sublist in all_dataset_configs for item in sublist] else: - torch.distributed.gather_object(dataset_dicts, [], dst=0) + torch.distributed.gather_object(dataset_configs, [], dst=0) - # Create a metadata file on rank 0 if self._config.distributed.rank == 0: - total_tokens = sum(dataset_dict["num_tokens"] for dataset_dict in dataset_dicts) - for dataset_dict in dataset_dicts: - dataset_dict["weight"] = float(dataset_dict["num_tokens"]) / float(total_tokens) - output_file = self._config.output_path / "fast_llm_dataset.json" - json.dump({"datasets": dataset_dicts}, output_file.open("w")) + # Create the config file(s) on rank 0 + if self._config.splits: + for split_name, split_config in self._split_and_blend_dataset_configs( + dataset_configs, self._config.splits + ).items(): + self._save_dataset_config( + split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" + ) + else: + self._save_dataset_config( + self._blend_dataset_configs(dataset_configs), self._config.output_path / f"fast_llm_config.yaml" + ) + # Save metadata on rank 0 self._save_croissant_metadata() - # Create an index file on rank 0 - index_file = self._config.output_path / "index.txt" - index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts]) - # Finalize distributed processing if self._config.distributed.world_size > 1: torch.distributed.barrier() torch.distributed.destroy_process_group() + + @classmethod + def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: + logger.info(f"Saving config to {output_path}") + yaml.safe_dump( + dataset_config.to_serialized(), + output_path.open("w"), + ) + + @classmethod + def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) -> GPTIndexedDatasetConfig: + if len(dataset_configs) == 1: + return dataset_configs[0] + return GPTIndexedDatasetConfig.from_dict( + { + "type": "blended", + "datasets": dataset_configs, + "weights": [dataset_config.num_tokens for dataset_config in dataset_configs], + } + ) + + @classmethod + def _split_and_blend_dataset_configs( + cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, 
int | float] + ) -> dict[str, GPTIndexedDatasetConfig]: + split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() + dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] + dataset_probabilities = normalize_probabilities(dataset_sizes) + dataset_cumsums = padded_cumsum(dataset_probabilities).tolist() + dataset_splits = {} + + for split_index, split_name in enumerate(splits): + datasets_in_split = [] + dataset_tokens_in_split = [] + for dataset_index, dataset_config in enumerate(dataset_configs): + split_begin_in_dataset = max( + (split_cumsum[split_index] - dataset_cumsums[dataset_index]) + / dataset_probabilities[dataset_index], + 0, + ) + split_end_in_dataset = min( + (split_cumsum[split_index + 1] - dataset_cumsums[dataset_index]) + / dataset_probabilities[dataset_index], + 1, + ) + if split_begin_in_dataset == 0 and split_end_in_dataset == 1: + # All the dataset belongs to the split. + datasets_in_split.append(dataset_configs[dataset_index]) + dataset_tokens_in_split.append(dataset_sizes[dataset_index]) + elif split_end_in_dataset > split_begin_in_dataset: + # Part of the dataset belongs to the split. + sizes_cumsum = dataset_config.build().get_document_sizes().cumsum() + Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) + begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) + end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + if end_index > begin_index: + datasets_in_split.append( + GPTDatasetSliceConfig.from_dict( + { + "type": "slice", + "dataset": dataset_configs[dataset_index], + "begin": begin_index / dataset_config.num_documents, + "end": end_index / dataset_config.num_documents, + } + ) + ) + dataset_tokens_in_split.append( + sizes_cumsum[end_index - 1].item() + - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + ) + + # [else] None of the dataset belongs to the split. + + if len(datasets_in_split) == 0: + # This is a big problem, but we don't want to crash the whole run. 
+ logger.error(f"Datasets split {split_name} is empty!") + elif len(datasets_in_split) == 1: + dataset_splits[split_name] = datasets_in_split[0] + else: + dataset_splits[split_name] = GPTBlendedDatasetConfig.from_dict( + { + "type": "blended", + "datasets": datasets_in_split, + "weights": dataset_tokens_in_split, + } + ) + + return dataset_splits + + +def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: + left = cumsum.searchsorted(value, side="right") + if left == len(cumsum): + return left.item() + return left + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left diff --git a/fast_llm/utils.py b/fast_llm/utils.py index c00e42ba..d650fa94 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -238,14 +238,15 @@ def log[ return logged -def normalize_probabilities(p: "npt.ArrayLike") -> list[float]: +def normalize_probabilities(p: "npt.ArrayLike", return_array: bool = False) -> "list[float] | np.ndarray": import numpy as np p = np.array(p) Assert.custom(lambda x: np.all(x >= 0), p) p_sum = p.sum() Assert.gt(p_sum, 0) - return (p / p_sum).tolist() + out = p / p_sum + return out if return_array else out.tolist() def set_nested_dict_value[ diff --git a/mkdocs.yaml b/mkdocs.yaml index 1d3a0892..47ac8cd6 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -167,7 +167,8 @@ nav: - StarCoder 2: success-stories/starcoder-2.md - License: license.md - Recipes: - - Data Preparation: recipes/data-preparation.md + - Prepare a dataset: recipes/data-preparation.md + - Configure a dataset: recipes/data-configuration.md - Train Llama 8B from scratch: recipes/train-llama-8b.md - Continue training Llama 8B: recipes/continue-training-llama-8b.md - Upcycle Llama 3B to MoE: recipes/upcycle-llama-3b-to-moe.md diff --git a/tests/common.py b/tests/common.py index 017a3ce0..6cec64e1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -9,6 +9,7 @@ import numpy as np import pytest import torch +import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -38,7 +39,7 @@ TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = TEST_RESULTS_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common" +DATASET_PREFIX = DATASET_CACHE / "common" / "dataset" DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" TEST_VOCAB_SIZE = 8192 @@ -272,7 +273,11 @@ def get_test_dataset( transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) - if not (prefix.with_suffix(".idx").is_file() and prefix.with_suffix(".bin").is_file()): + if not ( + prefix.with_suffix(".idx").is_file() + and prefix.with_suffix(".bin").is_file() + and prefix.parent.joinpath("fast_llm_config.yaml").is_file() + ): import transformers texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() @@ -289,6 +294,9 @@ def get_test_dataset( sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) GPTMemmapDataset.write_dataset(prefix, samples) + yaml.safe_dump( + {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") + ) def get_test_concatenated_memmap_dataset( diff --git a/tests/data/common.py b/tests/data/common.py index 668377e3..d326a93b 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -4,11 +4,16 @@ import numpy as np import torch -from fast_llm.config import NoAutoValidate +from fast_llm.config import Field, FieldHint, NoAutoValidate, 
config_class from fast_llm.data.data.gpt.config import GPTDataConfig, GPTSamplingDefaultConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingData, ShufflingType +from fast_llm.data.dataset.gpt.config import ( + GPTIndexedDatasetConfig, + GPTSampledDatasetConfig, + GPTSamplingData, + ShufflingType, +) from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset from fast_llm.data.tokenizer import Tokenizer @@ -103,7 +108,7 @@ def compare_indexed_dataset( ) -> None: Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() - Assert.eq(sizes.sum(), num_tokens) + # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] ) @@ -111,7 +116,6 @@ def compare_indexed_dataset( Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): - print("AAAAAA", dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, loss_masking_spans[i]) Assert.all_equal( dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2), @@ -163,3 +167,47 @@ def validate_indexed_dataset_sampling( if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) return token_ids + + +@config_class() +class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "mock_memmap" + num_documents: int | None = Field( + default=None, + desc="Expected number of documents in the dataset.", + hint=FieldHint.core, + ) + num_tokens_per_document: int | None = Field( + default=None, + desc="Expected number of tokens in the dataset.", + hint=FieldHint.optional, + ) + + def build(self) -> "GPTIndexedDataset": + return MockGPTMemmapDataset(self) + + @property + def num_tokens(self) -> int: + return self.num_documents * self.num_tokens_per_document + + +class MockGPTMemmapDataset(GPTIndexedDataset): + def __init__(self, config: MockGPTMemmapDatasetConfig): + self._config = config + + @property + def name(self) -> str: + return "mock_memmap" + + def __len__(self) -> int: + return self._config.num_documents + + def get_document_sizes(self) -> np.ndarray: + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + + def get_document_size(self, index: int) -> int: + return self._config.num_tokens_per_document + + def get(self, index: int, *args, **kwargs) -> typing.Any: + raise NotImplementedError() diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 56d84eaa..fa1bc2a9 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -6,7 +6,7 @@ from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities -from tests.common import DATASET_PREFIX, get_test_dataset +from tests.common import DATASET_CACHE, DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -14,7 +14,7 @@ get_test_data_and_compare_samples, ) -_DATASET_PREFIX_MIX_1 = DATASET_PREFIX.with_name("blended_mix_1") +_DATASET_PREFIX_MIX_1 = 
DATASET_CACHE / "blended_mix_1" / "dataset" def _get_test_dataset_mix_1(): diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py new file mode 100644 index 00000000..4ac2fcdf --- /dev/null +++ b/tests/data/test_dataset_from_file.py @@ -0,0 +1,12 @@ +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig +from tests.common import DATASET_PREFIX, get_test_dataset +from tests.data.common import compare_indexed_dataset, get_dataset_config +from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS + + +def test_dataset_from_file(): + get_test_dataset() + dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} + dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() + print("kjhbwiugfberibgiujebi", len(dataset)) + compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 6aaf83e8..be801220 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -3,7 +3,7 @@ import pytest from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from tests.common import DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset +from tests.common import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset from tests.data.common import compare_indexed_dataset, get_dataset_config MEMMAP_DATASET_LENGTH = 6153 @@ -31,11 +31,11 @@ def test_gpt_memmap(cache_directory): 15: [], } -_DATASET_PREFIX_SPANS = DATASET_PREFIX.with_name("with_spans") +_DATASET_PREFIX_SPANS = DATASET_CACHE / "with_spans" / "dataset" def test_gpt_data_with_spans(): - get_test_dataset(prefix=DATASET_PREFIX.with_name("with_spans"), max_spans=5) + get_test_dataset(prefix=_DATASET_PREFIX_SPANS, max_spans=5) dataset = get_dataset_config( { "type": "memmap", diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index d2810d12..b9e4d248 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -5,10 +5,12 @@ import numpy as np import pytest +from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator +from fast_llm.utils import Assert def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator: @@ -71,3 +73,96 @@ def test_absent_metadata_local(): ): get_preparator(local_folder, dataset_folder)._save_croissant_metadata() assert not (pathlib.Path(local_folder) / "croissant.json").is_file() + + +DATASET_DICT_0 = { + "type": "mock_memmap", + "num_documents": 500, + "num_tokens_per_document": 300, +} +DATASET_DICT_1 = { + "type": "mock_memmap", + "num_documents": 1500, + "num_tokens_per_document": 100, +} + + +def test_split_dataset(): + dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( + [dataset_config_0], + {"training": 3, "validation": 1}, + ) + config = {key: value.to_serialized() for key, value in config.items()} + + Assert.eq( + config, + { + "training": { + "type": "slice", + "dataset": 
dataset_config_0.to_serialized(), + "begin": 0, + "end": 0.75, + }, + "validation": { + "type": "slice", + "dataset": dataset_config_0.to_serialized(), + "begin": 0.75, + "end": 1, + }, + }, + ) + + +def test_split_datasets_0(): + dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( + [dataset_config_0, dataset_config_1], + {"training": 1, "validation": 1}, + ) + config = {key: value.to_serialized() for key, value in config.items()} + + Assert.eq( + config, + { + "training": dataset_config_0.to_serialized(), + "validation": dataset_config_1.to_serialized(), + }, + ) + + +def test_split_datasets_1(): + dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( + [dataset_config_0, dataset_config_1], + {"training": 3, "validation": 1}, + ) + config = {key: value.to_serialized() for key, value in config.items()} + + Assert.eq( + config, + { + "training": { + "type": "blended", + "name": "blended", + "datasets": [ + dataset_config_0.to_serialized(), + { + "type": "slice", + "dataset": dataset_config_1.to_serialized(), + "begin": 0, + "end": 0.5, + }, + ], + "weights": [2 / 3, 1 / 3], + }, + "validation": { + "type": "slice", + "dataset": dataset_config_1.to_serialized(), + "begin": 0.5, + "end": 1, + }, + }, + )
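
The proportional-split arithmetic applied by `_split_and_blend_dataset_configs`, and exercised by `test_split_datasets_1` above, can be summarized with a small, self-contained sketch. This is an illustration only, not code from the patch: it uses plain `numpy`, reuses the shard sizes from `DATASET_DICT_0`/`DATASET_DICT_1` (500 × 300 and 1500 × 100 tokens), and skips the rounding to document boundaries done by `_get_nearest_split` as well as the construction of the `slice`/`blended` configs.

```python
import numpy as np

# Token counts of the two mock shards used in the tests above:
# 500 documents x 300 tokens and 1500 documents x 100 tokens.
dataset_tokens = np.array([500 * 300, 1500 * 100], dtype=np.float64)
split_ratios = {"training": 3, "validation": 1}

# Normalize the split ratios and the dataset sizes, then place both on the
# same [0, 1] axis as cumulative sums padded with a leading zero.
split_cumsum = np.concatenate(([0.0], np.cumsum(list(split_ratios.values()), dtype=np.float64)))
split_cumsum /= split_cumsum[-1]
dataset_cumsum = np.concatenate(([0.0], np.cumsum(dataset_tokens)))
dataset_cumsum /= dataset_cumsum[-1]
dataset_fractions = dataset_tokens / dataset_tokens.sum()

for split_index, split_name in enumerate(split_ratios):
    for dataset_index, fraction in enumerate(dataset_fractions):
        # Fraction of this dataset (measured in tokens) that falls inside the split.
        begin = max((split_cumsum[split_index] - dataset_cumsum[dataset_index]) / fraction, 0.0)
        end = min((split_cumsum[split_index + 1] - dataset_cumsum[dataset_index]) / fraction, 1.0)
        if end > begin:
            print(f"{split_name}: {begin:.2f} .. {end:.2f} of dataset_{dataset_index}")

# Expected output, matching the begin/end values asserted in `test_split_datasets_1`:
#   training: 0.00 .. 1.00 of dataset_0
#   training: 0.00 .. 0.50 of dataset_1
#   validation: 0.50 .. 1.00 of dataset_1
```

The printed fractions line up with the test: the training split takes all of the first shard plus the first half of the second, and the validation split takes the remaining half, with blend weights proportional to the token counts in each part.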