From 904d1e140797e4901fdbc98e3ddcfb08cb3f2241 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 26 Feb 2024 17:07:58 -0500 Subject: [PATCH 01/89] Implemented baseline LoRA peft for one Nvidia GPU. --- configs/config.yaml | 8 ++ examples/example_lora.py | 153 ++++++++++++++++++++++++++++++++++ vectorlm/utils/model_utils.py | 25 +++++- 3 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 examples/example_lora.py diff --git a/configs/config.yaml b/configs/config.yaml index 2778f9c..906efee 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,6 +1,14 @@ model: /model-weights/Llama-2-7b-chat-hf enable_wandb_logging: True +lora_peft_config: + task_type: CAUSAL_LM + inference_mode: False + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + + wandb_config: project: MedGPT name: Llama-2-7B-chat diff --git a/examples/example_lora.py b/examples/example_lora.py new file mode 100644 index 0000000..3ac2878 --- /dev/null +++ b/examples/example_lora.py @@ -0,0 +1,153 @@ +# Renamed from examples/llama_example.py +import argparse +import math +import os +import sys +from argparse import Namespace + +import torch +import torch.distributed as dist +from torch.distributed.fsdp.fully_sharded_data_parallel import \ + FullyShardedDataParallel as FSDP +from torch.optim import AdamW +from tqdm import tqdm +from transformers import set_seed +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +from vectorlm.dataset import Dataset +from vectorlm.trainer import Trainer +from vectorlm.utils.data_utils import Config +from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup +from vectorlm.utils.model_utils import (hook_activation_checkpointing, + initialize_lora_model_and_tokenizer, + shard_model_manual) +from vectorlm.utils.optimizer_utils import get_custom_scheduler +from vectorlm.utils.save_utils import save_consolidated_model + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns + ------- + The parsed arguments. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--yaml_path", + default="configs/config.yaml", + required=False, + ) + return parser.parse_args() + + +def main(config: Config) -> None: + """Define the main calling function.""" + training_args = config.train_parameters + + # set a seed + set_seed(training_args.seed) + + # set CUDA related dependencies + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + print(f"Rank: {rank}, World size: {world_size}") + if dist.is_initialized(): + torch.cuda.set_device(local_rank) + torch.cuda.empty_cache() + + # setup wandb + if rank == 0: + wandb_setup(config, **config.wandb_config) + dist.barrier() + + # load model and tokenizer + state_dict_path = getattr(config, "state_dict", None) + + model, tokenizer = initialize_lora_model_and_tokenizer( + config.model, + training_args.use_mp, + training_args.use_flash_attention, + training_args.max_seq_len, + config.lora_peft_config, + ) + + if training_args.use_activation_checkpointing: + hook_activation_checkpointing(model, LlamaDecoderLayer) + + # load dataset + dataset = Dataset( + config=config.dataset, + tokenizer=tokenizer, + ) + + # instantiate trainer + trainer = Trainer( + config=training_args, + enable_wandb_logging=config.enable_wandb_logging, + original_dataset_length=dataset.original_length, + ) + + # load optimizer + optimizer = AdamW( + model.parameters(), + **training_args.optimizer, + ) + + # load lr scheduler + lr_scheduler = get_custom_scheduler( + training_args.lr_scheduler_type, + optimizer, + math.ceil( + trainer.num_update_steps_per_epoch * training_args.warmup_ratio, + ), + trainer.max_steps, + ) + + trainer.prepare_trainer( + model, + tokenizer, + dataset, + optimizer, + lr_scheduler, + ) + + # Checkpoint check. Always call before training. + # If no checkpoint, it returns 0. + checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + + for epoch in range(checkpointed_epoch, training_args.epochs): + trainer.model.train() + train_dl_iterator = iter(dataset.train_dataloader) + for _ in tqdm( + range(len(dataset.train_dataloader)), + disable=rank != 0, + file=sys.__stdout__, + ): + batch = next(train_dl_iterator) + trainer.step(batch, epoch) + + if epoch == training_args.epochs - 1: + hf_save_dir = os.path.join(training_args.output_dir, "final-model") + else: + hf_save_dir = os.path.join( + training_args.output_dir, + "checkpoints", + f"epoch_{epoch}", + "end-epoch-model", + ) + save_consolidated_model(trainer.model, hf_save_dir, rank) + if rank == 0: + tokenizer.save_pretrained(hf_save_dir) + + dataset.reset_dataloaders() + + +if __name__ == "__main__": + args = parse_args() + config = Config(yaml_path=args.yaml_path) + setup(config.train_parameters.output_dir) + main(config) + cleanup() diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 09d5c27..0924c0b 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -1,11 +1,11 @@ from __future__ import annotations import functools -from typing import Any +from typing import Any, Dict, Optional, Tuple import torch import torch.distributed as dist -from peft import PeftConfig, PeftModel +from peft import LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, @@ -25,6 +25,27 @@ ) +def initialize_lora_model_and_tokenizer( + path: str, + use_mp: bool, + use_fa: bool, + max_seq_len: int, + peft_config_dict: Dict[str, Any], +) -> tuple[PeftModel, PreTrainedTokenizer]: + """ + Initialize lora peft configuration for a non-lora model. + """ + model, tokenizer = load_model_and_tokenizer(path, use_mp, use_fa, max_seq_len) + + # Replace task type string in config with TaskType member. + task_type_str = peft_config_dict["task_type"] + task_type = getattr(TaskType, task_type_str) + lora_config = LoraConfig(**{**peft_config_dict, "task_type": task_type}) + + lora_model = get_peft_model(model, lora_config) + return lora_model, tokenizer + + def load_peft_model_and_tokenizer( path: str, use_mp: bool, From 2ace67ecb63bc74d26e1f48dc0b83e6e6f23174f Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 27 Feb 2024 16:50:06 -0500 Subject: [PATCH 02/89] Added support for saving lora adapters. Added support for non-fsdp models. --- vectorlm/trainer.py | 20 +++++++++++++++++--- vectorlm/utils/save_utils.py | 10 ++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 173c892..3954509 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -238,15 +238,28 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: ids = batch.pop("id").to(torch.cuda.current_device()) batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) batch["labels"] = batch["labels"].type(torch.LongTensor) + batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} self.dataset.update_processed_ids(ids) if (self.tr_step + 1) % self.gas != self.gas - 1: - # no need to sync while accumulating gradients - with self.model.no_sync(): + if hasattr(self.model, "no_sync"): + # fsdp: no need to sync while accumulating gradients + with self.model.no_sync(): + out = self.model(**batch) + tr_step_loss = out.loss + (tr_step_loss / self.gas).backward() + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.max_grad_norm + ) + else: + # non-fsdp out = self.model(**batch) tr_step_loss = out.loss (tr_step_loss / self.gas).backward() - self.model.clip_grad_norm_(self.config.max_grad_norm) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.max_grad_norm + ) + else: # next forward / backward pass will be synced dist.barrier() @@ -283,6 +296,7 @@ def eval_step(self, epoch: int) -> float: batch.pop("id") batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) batch["labels"] = batch["labels"].type(torch.LongTensor) + batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} out = self.model(**batch) eval_loss += out.loss gathered_eval_loss = _gather(eval_loss.reshape(1)).mean().item() diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index aef9993..4a666a0 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -4,6 +4,7 @@ import re import torch +from peft.peft_model import PeftModel from torch import nn from torch.distributed.fsdp import ( FullStateDictConfig, # general model non-sharded, non-flattened params @@ -116,6 +117,11 @@ def save_model(model: nn.Module, output_dir: str, rank: int) -> None: os.makedirs(output_dir, exist_ok=True) weights_name = f"model_rank{rank}.bin" output_model_file = os.path.join(output_dir, weights_name) + + if isinstance(model, PeftModel): + model.save_pretrained(output_model_file) + return + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): print(f"Saving model to {output_model_file}") state_dict = model.state_dict() @@ -155,6 +161,10 @@ def save_consolidated_model( save_dir: The checkpointing directory. rank: The worker's rank. """ + if isinstance(model, PeftModel): + model.save_pretrained(save_dir) + return + os.makedirs(save_dir, exist_ok=True) cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) save_path = os.path.join(save_dir, "model.bin") From a25e667ad64aba5c63852e03d73f274dc5a35e7c Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 29 Feb 2024 09:43:05 -0500 Subject: [PATCH 03/89] save_utils: added support for non-FSDP optimizers. trainer: replaced clip_grad_norm_ with nn.utils.clip_grad_norm_ for lora compatibility. --- vectorlm/trainer.py | 4 +++- vectorlm/utils/save_utils.py | 24 ++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 3954509..0325468 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -266,7 +266,9 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: out = self.model(**batch) tr_step_loss = out.loss (tr_step_loss / self.gas).backward() - self.model.clip_grad_norm_(self.config.max_grad_norm) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.max_grad_norm + ) self.optimizer.step() if isinstance(self.lr_scheduler, ReduceLROnPlateau): self.lr_scheduler.step(self.metric) diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index 4a666a0..dd1bbea 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -192,12 +192,24 @@ def save_optimizer( opt_name = f"optimizer_rank{rank}.bin" output_optimizer_file = os.path.join(output_dir, opt_name) opt_cfg = LocalOptimStateDictConfig(offload_to_cpu=True) - with FSDP.state_dict_type( - model, - StateDictType.LOCAL_STATE_DICT, - optim_state_dict_config=opt_cfg, - ): - opt_state = FSDP.optim_state_dict(model, optimizer) + + try: + with FSDP.state_dict_type( + model, + StateDictType.LOCAL_STATE_DICT, + optim_state_dict_config=opt_cfg, + ): + opt_state = FSDP.optim_state_dict(model, optimizer) + + print(f"Saving optimizer state to {output_optimizer_file}") + torch.save(opt_state, output_optimizer_file) + print(f"Optimizer state saved to {output_optimizer_file}") + + except AttributeError: + # One GPU only. Optimizer isn't sharded. + opt_state = optimizer.state_dict() + print("Optimizer state is retrieved as non-sharded") + print(f"Saving optimizer state to {output_optimizer_file}") torch.save(opt_state, output_optimizer_file) print(f"Optimizer state saved to {output_optimizer_file}") From 65a2dbf41802193bc470d20ac07f7d9d83b13039 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 29 Feb 2024 09:44:33 -0500 Subject: [PATCH 04/89] example_lora: highlighted current lora (non-fsdp) limitations. --- examples/example_lora.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/example_lora.py b/examples/example_lora.py index 3ac2878..1761585 100644 --- a/examples/example_lora.py +++ b/examples/example_lora.py @@ -7,8 +7,6 @@ import torch import torch.distributed as dist -from torch.distributed.fsdp.fully_sharded_data_parallel import \ - FullyShardedDataParallel as FSDP from torch.optim import AdamW from tqdm import tqdm from transformers import set_seed @@ -18,9 +16,10 @@ from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup -from vectorlm.utils.model_utils import (hook_activation_checkpointing, - initialize_lora_model_and_tokenizer, - shard_model_manual) +from vectorlm.utils.model_utils import ( + hook_activation_checkpointing, + initialize_lora_model_and_tokenizer, +) from vectorlm.utils.optimizer_utils import get_custom_scheduler from vectorlm.utils.save_utils import save_consolidated_model @@ -74,6 +73,10 @@ def main(config: Config) -> None: config.lora_peft_config, ) + # One GPU only. + assert dist.get_world_size() == 1 + model = model.cuda() + if training_args.use_activation_checkpointing: hook_activation_checkpointing(model, LlamaDecoderLayer) @@ -114,9 +117,9 @@ def main(config: Config) -> None: lr_scheduler, ) - # Checkpoint check. Always call before training. - # If no checkpoint, it returns 0. - checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + # TODO: support restoring LoRA fine-tuning + trainer.dataset.setup_dataloaders() + checkpointed_epoch = 0 for epoch in range(checkpointed_epoch, training_args.epochs): trainer.model.train() From ed4c84f76346e83a187750197f01f58c37431de3 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 29 Feb 2024 09:51:37 -0500 Subject: [PATCH 05/89] Added instructions on LoRA on one GPU. --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a327e6e..08d9657 100644 --- a/README.md +++ b/README.md @@ -63,9 +63,15 @@ We have provided an example script to show what a regular workflow would look li At the end of training, a consolidated model will be saved under your output directory as a `.bin` file. You can simply just run [`vectorlm/utils/convert_to_hf.py`](vectorlm/utils/convert_to_hf.py) to convert it to the regular HuggingFace model format. The script uses the main config file to determine save locations. +### Example: LoRA on one GPU + +We provide an additional example of parameter-efficient fine-tuning (PEFT) using LoRA on one NVIDIA GPU. Use the [`examples/launch_lora_one_gpu.sh`](examples/launch_lora_one_gpu.sh) to launch your job on the cluster. + +This [slide](https://docs.google.com/presentation/d/1ju7nItD0Xvnq_w5g25w91SpKnkpGWkOSRrjp8N-2TYM/edit?usp=sharing) provides more detail about the LoRA implementation in VectorLM, as well as challenges related to integrating LoRA with FSDP. + ## Roadmap -- PEFT methods (LoRA). +- PEFT methods (LoRA + FSDP). # Contributors -Adil Asif, Ziwen Han, John Willes. +Adil Asif, Ziwen Han, John Willes, Jacob-Junqi Tian. From 5a723927e9cc534c890c0a20a7e2392519136379 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 29 Feb 2024 09:54:09 -0500 Subject: [PATCH 06/89] Added example script for launching lora. --- examples/launch_lora_one_gpu.sh | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 examples/launch_lora_one_gpu.sh diff --git a/examples/launch_lora_one_gpu.sh b/examples/launch_lora_one_gpu.sh new file mode 100644 index 0000000..4390781 --- /dev/null +++ b/examples/launch_lora_one_gpu.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH --job-name=llama7b-2-lora +#SBATCH --nodes=1 +#SBATCH --mem=0 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-gpu=6 +#SBATCH --gres=gpu:1 +#SBATCH --output=llama-2-7b-lora.%j.out +#SBATCH --error=llama-2-7b-lora.%j.err +#SBATCH --partition=a100 +#SBATCH --qos=your_assigned_qos # CHANGE +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=3-00 + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export NCCL_DEBUG=WARN +export NCCL_DEBUG_SUBSYS=WARN + +# export TORCH_DISTRIBUTED_DEBUG=DETAIL # Uncomment these flags for debugging communication +# export TORCH_CPP_LOG_LEVEL=INFO +export LOGLEVEL=INFO +export PYTHONFAULTHANDLER=1 +# export CUDA_LAUNCH_BLOCKING=0 + +torchrun --nnodes=1 --nproc-per-node=1 example_lora.py --yaml_path configs/config-lora.yaml From e176ac829fc787b81fdc739dae1a8e84ecbff66d Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 29 Feb 2024 09:54:37 -0500 Subject: [PATCH 07/89] Revised instructions on LoRA on one GPU. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 08d9657..baadbd3 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ At the end of training, a consolidated model will be saved under your output dir We provide an additional example of parameter-efficient fine-tuning (PEFT) using LoRA on one NVIDIA GPU. Use the [`examples/launch_lora_one_gpu.sh`](examples/launch_lora_one_gpu.sh) to launch your job on the cluster. -This [slide](https://docs.google.com/presentation/d/1ju7nItD0Xvnq_w5g25w91SpKnkpGWkOSRrjp8N-2TYM/edit?usp=sharing) provides more detail about the LoRA implementation in VectorLM, as well as challenges related to integrating LoRA with FSDP. +This [slide deck](https://docs.google.com/presentation/d/1ju7nItD0Xvnq_w5g25w91SpKnkpGWkOSRrjp8N-2TYM/edit?usp=sharing) provides more detail about the LoRA implementation in VectorLM, as well as challenges related to integrating LoRA with FSDP. ## Roadmap - PEFT methods (LoRA + FSDP). From 2d869b09d152fba8020068ffdede0b5746fad447 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 6 Mar 2024 15:56:54 -0500 Subject: [PATCH 08/89] Implemented LoRA FSDP. Also see https://github.com/facebookresearch/llama-recipes/blob/674b37ee66f59a7845cbc3868948f4d7fa69c679/src/llama_recipes/utils/fsdp_utils.py#L9 --- README.md | 17 +++++---- configs/config-lora.yaml | 66 +++++++++++++++++++++++++++++++++++ examples/example_lora.py | 11 ++++-- examples/launch_lora.sh | 26 ++++++++++++++ vectorlm/utils/model_utils.py | 32 ++++++++++++++--- 5 files changed, 136 insertions(+), 16 deletions(-) create mode 100644 configs/config-lora.yaml create mode 100644 examples/launch_lora.sh diff --git a/README.md b/README.md index baadbd3..c742825 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,9 @@ It is heavily recommended that you use Flash Attention-2, please follow the inst ## Introduction VectorLM is a training package built upon HuggingFace models and PyTorch Fully Sharded Data Parallelism. The package has been built around throughput optimizations. It is targeted at largely simplifying the workflow to setup distributed schemes while training **medium-sized** models in **resource-constrained** environments. This is especially true for academic clusters where powerful GPUs are available, but are bottlenecked by interconnectivity. Thus, there are two goals of this light-weight package: -* Use simple sharding strategies. FSDP is a great option for medium-sized model training. It is well maintained by the PyTorch team. -* Employ several optimization techniques to make scaling to larger models possible whilst minimizing memory usage and communication volume. As a result, we are able to efficiently dense finetune LLMs of sizes up to 13B parameters on the Vector cluster. + +- Use simple sharding strategies. FSDP is a great option for medium-sized model training. It is well maintained by the PyTorch team. +- Employ several optimization techniques to make scaling to larger models possible whilst minimizing memory usage and communication volume. As a result, we are able to efficiently dense finetune LLMs of sizes up to 13B parameters on the Vector cluster.
What is FSDP? @@ -42,6 +43,7 @@ Our package is designed for lightweight operations and is not intended for train
## Global Configuration + The central configuration that is used across data preprocessing, training, and dataset loading is under [`configs/config.yaml`](configs/config.yaml). All arguments, as well as recommendations, are documented under [`docs/config.md`](docs/config.md). ## Data Preprocessing @@ -54,8 +56,8 @@ We implement several training optimizations that can be reviewed under [`docs/tr ### Main Classes -* [`Dataset`](vectorlm/dataset.py): It loads the training and test sets as processed by data processing script above. It also sets the dataloaders and shards them across devices. -* [`Trainer`](vectorlm/trainer.py): The main trainer class. It contains the model, optimizer, LR scheduler, and dataloaders. It also performs the training and evaluation steps as well as state checkpointing. +- [`Dataset`](vectorlm/dataset.py): It loads the training and test sets as processed by data processing script above. It also sets the dataloaders and shards them across devices. +- [`Trainer`](vectorlm/trainer.py): The main trainer class. It contains the model, optimizer, LR scheduler, and dataloaders. It also performs the training and evaluation steps as well as state checkpointing. ### Example: Llama-2 @@ -63,14 +65,11 @@ We have provided an example script to show what a regular workflow would look li At the end of training, a consolidated model will be saved under your output directory as a `.bin` file. You can simply just run [`vectorlm/utils/convert_to_hf.py`](vectorlm/utils/convert_to_hf.py) to convert it to the regular HuggingFace model format. The script uses the main config file to determine save locations. -### Example: LoRA on one GPU - -We provide an additional example of parameter-efficient fine-tuning (PEFT) using LoRA on one NVIDIA GPU. Use the [`examples/launch_lora_one_gpu.sh`](examples/launch_lora_one_gpu.sh) to launch your job on the cluster. +### Example: LoRA FSDP -This [slide deck](https://docs.google.com/presentation/d/1ju7nItD0Xvnq_w5g25w91SpKnkpGWkOSRrjp8N-2TYM/edit?usp=sharing) provides more detail about the LoRA implementation in VectorLM, as well as challenges related to integrating LoRA with FSDP. +We provide an additional example of parameter-efficient fine-tuning (PEFT) using LoRA and FSDP. Use the [`examples/launch_lora.sh`](examples/launch_lora.sh) to launch your job on the cluster. ## Roadmap -- PEFT methods (LoRA + FSDP). # Contributors diff --git a/configs/config-lora.yaml b/configs/config-lora.yaml new file mode 100644 index 0000000..04e9046 --- /dev/null +++ b/configs/config-lora.yaml @@ -0,0 +1,66 @@ +model: facebook/opt-125m +enable_wandb_logging: True + +lora_peft_config: + task_type: CAUSAL_LM + inference_mode: False + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + +wandb_config: + project: vector-lm-verify + name: opt-125m-lora + +train_parameters: + output_dir: data/model/opt-125m-gsm8k-lora + max_seq_len: 1024 + epochs: 1 + seed: 11 + + # Sharding strategy + sharding_strategy: FULL_SHARD + + # Memory + use_mp: True + use_activation_checkpointing: True + use_flash_attention: True + + # Gradient norm clipping + max_grad_norm: 1 + gradient_accumulation_steps: 4 + + # Optimizer + optimizer: + lr: 2.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.95] + eps: 1.0e-5 + + # Scheduler + lr_scheduler_type: cosine + warmup_ratio: 0.05 + + # Checkpointing + checkpointing_enabled: True + logging_steps: 500 + save_frequency: 0.25 + +dataset: + ignore_index: -100 + eval_bs: 8 + train_bs: 8 + train_ds: data/processed/gsm8k-question/train + eval_ds: data/processed/gsm8k-question/test + +dataset_preprocess: + ignore_index: -100 + dataset_format: hf + data_field: question + packing_type: partial + add_bos_eos_tokens: True + from_disk: True + load_path: data/raw/gsm8k + split: train + save_path: data/processed/gsm8k-question/train + truncate: False diff --git a/examples/example_lora.py b/examples/example_lora.py index 1761585..2f3a61e 100644 --- a/examples/example_lora.py +++ b/examples/example_lora.py @@ -19,6 +19,7 @@ from vectorlm.utils.model_utils import ( hook_activation_checkpointing, initialize_lora_model_and_tokenizer, + shard_model, ) from vectorlm.utils.optimizer_utils import get_custom_scheduler from vectorlm.utils.save_utils import save_consolidated_model @@ -73,9 +74,13 @@ def main(config: Config) -> None: config.lora_peft_config, ) - # One GPU only. - assert dist.get_world_size() == 1 - model = model.cuda() + model = shard_model( + model, + LlamaDecoderLayer, + training_args.use_mp, + training_args.use_activation_checkpointing, + training_args.sharding_strategy, + ) if training_args.use_activation_checkpointing: hook_activation_checkpointing(model, LlamaDecoderLayer) diff --git a/examples/launch_lora.sh b/examples/launch_lora.sh new file mode 100644 index 0000000..76f68e1 --- /dev/null +++ b/examples/launch_lora.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH --job-name=llama7b-2 +#SBATCH --nodes=1 +#SBATCH --mem=0 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-gpu=6 +#SBATCH --gres=gpu:4 +#SBATCH --output=llama-2-7b.%j.out +#SBATCH --error=llama-2-7b.%j.err +#SBATCH --partition=a100 +#SBATCH --qos=your_assigned_qos # CHANGE +#SBATCH --open-mode=append +#SBATCH --wait-all-nodes=1 +#SBATCH --time=3-00 + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export NCCL_DEBUG=WARN +export NCCL_DEBUG_SUBSYS=WARN + +# export TORCH_DISTRIBUTED_DEBUG=DETAIL # Uncomment these flags for debugging communication +# export TORCH_CPP_LOG_LEVEL=INFO +export LOGLEVEL=INFO +export PYTHONFAULTHANDLER=1 +# export CUDA_LAUNCH_BLOCKING=0 + +torchrun --nnodes=1 --nproc-per-node=2 example_lora.py --yaml_path configs/config-lora.yaml diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 0924c0b..0f6d045 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -16,7 +16,12 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, ) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ( + transformer_auto_wrap_policy, + _or_policy, + lambda_auto_wrap_policy, +) + from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -93,6 +98,7 @@ def load_peft_model_and_tokenizer( ) return peft_model, tokenizer + def load_model_and_tokenizer( path: str, use_mp: bool, @@ -173,9 +179,25 @@ def fsdp_config( ) ret_dict["mixed_precision"] = mp_policy + # See https://github.com/facebookresearch/llama-recipes/blob/674b37ee6/src/llama_recipes/utils/fsdp_utils.py#L9 + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn + ) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=[layer_to_wrap] + ) + auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={layer_to_wrap}, + _or_policy, policies=[lambda_policy, transformer_wrap_policy] ) sharding_strategy = getattr(ShardingStrategy, strategy) @@ -239,5 +261,7 @@ def hook_activation_checkpointing( check_fn = lambda submodule: isinstance(submodule, layer) apply_activation_checkpointing( - model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn, + model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=check_fn, ) From dc098d63d755c2d3c97e1fec2c3a93e9a7bd4a65 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 6 Mar 2024 16:01:06 -0500 Subject: [PATCH 09/89] Reverted automatic formatter changes in README.md --- README.md | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index c742825..d98e86a 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,8 @@ It is heavily recommended that you use Flash Attention-2, please follow the inst ## Introduction VectorLM is a training package built upon HuggingFace models and PyTorch Fully Sharded Data Parallelism. The package has been built around throughput optimizations. It is targeted at largely simplifying the workflow to setup distributed schemes while training **medium-sized** models in **resource-constrained** environments. This is especially true for academic clusters where powerful GPUs are available, but are bottlenecked by interconnectivity. Thus, there are two goals of this light-weight package: - -- Use simple sharding strategies. FSDP is a great option for medium-sized model training. It is well maintained by the PyTorch team. -- Employ several optimization techniques to make scaling to larger models possible whilst minimizing memory usage and communication volume. As a result, we are able to efficiently dense finetune LLMs of sizes up to 13B parameters on the Vector cluster. +* Use simple sharding strategies. FSDP is a great option for medium-sized model training. It is well maintained by the PyTorch team. +* Employ several optimization techniques to make scaling to larger models possible whilst minimizing memory usage and communication volume. As a result, we are able to efficiently dense finetune LLMs of sizes up to 13B parameters on the Vector cluster.
What is FSDP? @@ -43,7 +42,6 @@ Our package is designed for lightweight operations and is not intended for train
## Global Configuration - The central configuration that is used across data preprocessing, training, and dataset loading is under [`configs/config.yaml`](configs/config.yaml). All arguments, as well as recommendations, are documented under [`docs/config.md`](docs/config.md). ## Data Preprocessing @@ -56,8 +54,8 @@ We implement several training optimizations that can be reviewed under [`docs/tr ### Main Classes -- [`Dataset`](vectorlm/dataset.py): It loads the training and test sets as processed by data processing script above. It also sets the dataloaders and shards them across devices. -- [`Trainer`](vectorlm/trainer.py): The main trainer class. It contains the model, optimizer, LR scheduler, and dataloaders. It also performs the training and evaluation steps as well as state checkpointing. +* [`Dataset`](vectorlm/dataset.py): It loads the training and test sets as processed by data processing script above. It also sets the dataloaders and shards them across devices. +* [`Trainer`](vectorlm/trainer.py): The main trainer class. It contains the model, optimizer, LR scheduler, and dataloaders. It also performs the training and evaluation steps as well as state checkpointing. ### Example: Llama-2 @@ -69,8 +67,6 @@ At the end of training, a consolidated model will be saved under your output dir We provide an additional example of parameter-efficient fine-tuning (PEFT) using LoRA and FSDP. Use the [`examples/launch_lora.sh`](examples/launch_lora.sh) to launch your job on the cluster. -## Roadmap - # Contributors Adil Asif, Ziwen Han, John Willes, Jacob-Junqi Tian. From 5a1fd76c18fe10eb1203ae3400053eae8b2df6d7 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 6 Mar 2024 16:22:03 -0500 Subject: [PATCH 10/89] Eliminated non-FSDP logic from save_utils. Set model path to local copy of llama-2-7b in example config. --- configs/config-lora.yaml | 2 +- vectorlm/utils/save_utils.py | 36 +++++++----------------------------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/configs/config-lora.yaml b/configs/config-lora.yaml index 04e9046..affbe68 100644 --- a/configs/config-lora.yaml +++ b/configs/config-lora.yaml @@ -1,4 +1,4 @@ -model: facebook/opt-125m +model: /model-weights/Llama-2-7b-chat-hf enable_wandb_logging: True lora_peft_config: diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index dd1bbea..36998af 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -4,7 +4,6 @@ import re import torch -from peft.peft_model import PeftModel from torch import nn from torch.distributed.fsdp import ( FullStateDictConfig, # general model non-sharded, non-flattened params @@ -117,11 +116,6 @@ def save_model(model: nn.Module, output_dir: str, rank: int) -> None: os.makedirs(output_dir, exist_ok=True) weights_name = f"model_rank{rank}.bin" output_model_file = os.path.join(output_dir, weights_name) - - if isinstance(model, PeftModel): - model.save_pretrained(output_model_file) - return - with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): print(f"Saving model to {output_model_file}") state_dict = model.state_dict() @@ -161,10 +155,6 @@ def save_consolidated_model( save_dir: The checkpointing directory. rank: The worker's rank. """ - if isinstance(model, PeftModel): - model.save_pretrained(save_dir) - return - os.makedirs(save_dir, exist_ok=True) cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) save_path = os.path.join(save_dir, "model.bin") @@ -192,24 +182,12 @@ def save_optimizer( opt_name = f"optimizer_rank{rank}.bin" output_optimizer_file = os.path.join(output_dir, opt_name) opt_cfg = LocalOptimStateDictConfig(offload_to_cpu=True) - - try: - with FSDP.state_dict_type( - model, - StateDictType.LOCAL_STATE_DICT, - optim_state_dict_config=opt_cfg, - ): - opt_state = FSDP.optim_state_dict(model, optimizer) - - print(f"Saving optimizer state to {output_optimizer_file}") - torch.save(opt_state, output_optimizer_file) - print(f"Optimizer state saved to {output_optimizer_file}") - - except AttributeError: - # One GPU only. Optimizer isn't sharded. - opt_state = optimizer.state_dict() - print("Optimizer state is retrieved as non-sharded") - + with FSDP.state_dict_type( + model, + StateDictType.LOCAL_STATE_DICT, + optim_state_dict_config=opt_cfg, + ): + opt_state = FSDP.optim_state_dict(model, optimizer) print(f"Saving optimizer state to {output_optimizer_file}") torch.save(opt_state, output_optimizer_file) print(f"Optimizer state saved to {output_optimizer_file}") @@ -286,4 +264,4 @@ def load_scheduler( print(f"Loading scheduler state from {input_scheduler_file}") state_dict = torch.load(input_scheduler_file) scheduler.load_state_dict(state_dict) - print(f"Scheduler state loaded from {input_scheduler_file}") + print(f"Scheduler state loaded from {input_scheduler_file}") \ No newline at end of file From 7e187bc20e9584cb8c753253b11a53a6612242eb Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 6 Mar 2024 16:26:21 -0500 Subject: [PATCH 11/89] Moved lora config out of example config.yaml. --- configs/config-lora.yaml | 16 ++++++++-------- configs/config.yaml | 8 -------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/configs/config-lora.yaml b/configs/config-lora.yaml index affbe68..e9eca95 100644 --- a/configs/config-lora.yaml +++ b/configs/config-lora.yaml @@ -9,11 +9,11 @@ lora_peft_config: lora_dropout: 0.1 wandb_config: - project: vector-lm-verify - name: opt-125m-lora + project: MedGPT + name: Llama-2-7B-chat train_parameters: - output_dir: data/model/opt-125m-gsm8k-lora + output_dir: your/output/dir max_seq_len: 1024 epochs: 1 seed: 11 @@ -50,17 +50,17 @@ dataset: ignore_index: -100 eval_bs: 8 train_bs: 8 - train_ds: data/processed/gsm8k-question/train - eval_ds: data/processed/gsm8k-question/test + train_ds: your/train/ds + eval_ds: your/eval/ds dataset_preprocess: ignore_index: -100 dataset_format: hf - data_field: question + data_field: text packing_type: partial add_bos_eos_tokens: True from_disk: True - load_path: data/raw/gsm8k + load_path: your/unprocessed/dataset split: train - save_path: data/processed/gsm8k-question/train + save_path: dir/to/save/processed/dataset truncate: False diff --git a/configs/config.yaml b/configs/config.yaml index 906efee..2778f9c 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,14 +1,6 @@ model: /model-weights/Llama-2-7b-chat-hf enable_wandb_logging: True -lora_peft_config: - task_type: CAUSAL_LM - inference_mode: False - r: 8 - lora_alpha: 32 - lora_dropout: 0.1 - - wandb_config: project: MedGPT name: Llama-2-7B-chat From 3eea331eecf7e2c028a250709e85699896957d7d Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 12:44:43 -0400 Subject: [PATCH 12/89] Implemented LoRA benchmarking logic for worker. --- .gitignore | 3 +- example_lora.py | 162 +++++++++++++++++++ launch_lora_benchmark.py | 58 +++++++ lora_benchmark.py | 264 +++++++++++++++++++++++++++++++ vectorlm/tests/__init__.py | 0 vectorlm/tests/test_modelling.py | 256 ++++++++++++++++++++++++++++++ vectorlm/trainer.py | 49 ++++-- vectorlm/utils/model_utils.py | 71 +++++---- 8 files changed, 818 insertions(+), 45 deletions(-) create mode 100644 example_lora.py create mode 100644 launch_lora_benchmark.py create mode 100644 lora_benchmark.py create mode 100644 vectorlm/tests/__init__.py create mode 100644 vectorlm/tests/test_modelling.py diff --git a/.gitignore b/.gitignore index fdf00fd..f79c4ac 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ **/*.sh __pycache__/ wandb/ -build/ \ No newline at end of file +build/ +data/ diff --git a/example_lora.py b/example_lora.py new file mode 100644 index 0000000..a6b298d --- /dev/null +++ b/example_lora.py @@ -0,0 +1,162 @@ +# Renamed from examples/llama_example.py +import argparse +import math +import os +import sys +from argparse import Namespace + +import torch +import torch.distributed as dist +from torch.optim import AdamW +from tqdm import tqdm +from transformers import set_seed +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from peft.utils.other import fsdp_auto_wrap_policy + +from vectorlm.dataset import Dataset +from vectorlm.trainer import Trainer +from vectorlm.utils.data_utils import Config +from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup +from vectorlm.utils.model_utils import ( + hook_activation_checkpointing, + initialize_lora_model_and_tokenizer, + shard_model, +) +from vectorlm.utils.optimizer_utils import get_custom_scheduler +from vectorlm.utils.save_utils import save_consolidated_model + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns + ------- + The parsed arguments. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--yaml_path", + default="configs/config.yaml", + required=False, + ) + return parser.parse_args() + + +def main(config: Config) -> None: + """Define the main calling function.""" + training_args = config.train_parameters + + # set a seed + set_seed(training_args.seed) + + # set CUDA related dependencies + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + print(f"Rank: {rank}, World size: {world_size}") + if dist.is_initialized(): + torch.cuda.set_device(local_rank) + torch.cuda.empty_cache() + + # setup wandb + if rank == 0: + wandb_setup(config, **config.wandb_config) + dist.barrier() + + # load model and tokenizer + state_dict_path = getattr(config, "state_dict", None) + + model, tokenizer = initialize_lora_model_and_tokenizer( + config.model, + training_args.use_mp, + training_args.use_flash_attention, + training_args.max_seq_len, + config.lora_peft_config, + ) + + model = shard_model( + model, + LlamaDecoderLayer, + training_args.use_mp, + training_args.use_activation_checkpointing, + training_args.sharding_strategy, + ) + + if training_args.use_activation_checkpointing: + hook_activation_checkpointing(model, LlamaDecoderLayer) + + # load dataset + dataset = Dataset( + config=config.dataset, + tokenizer=tokenizer, + ) + + # instantiate trainer + trainer = Trainer( + config=training_args, + enable_wandb_logging=config.enable_wandb_logging, + original_dataset_length=dataset.original_length, + ) + + # load optimizer + optimizer = AdamW( + model.parameters(), + **training_args.optimizer, + ) + + # load lr scheduler + lr_scheduler = get_custom_scheduler( + training_args.lr_scheduler_type, + optimizer, + math.ceil( + trainer.num_update_steps_per_epoch * training_args.warmup_ratio, + ), + trainer.max_steps, + ) + + trainer.prepare_trainer( + model, + tokenizer, + dataset, + optimizer, + lr_scheduler, + ) + + # TODO: support restoring LoRA fine-tuning + trainer.dataset.setup_dataloaders() + checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + + for epoch in range(checkpointed_epoch, training_args.epochs): + trainer.model.train() + train_dl_iterator = iter(dataset.train_dataloader) + for _ in tqdm( + range(len(dataset.train_dataloader)), + disable=rank != 0, + file=sys.__stdout__, + ): + batch = next(train_dl_iterator) + trainer.step(batch, epoch) + + if epoch == training_args.epochs - 1: + hf_save_dir = os.path.join(training_args.output_dir, "final-model") + else: + hf_save_dir = os.path.join( + training_args.output_dir, + "checkpoints", + f"epoch_{epoch}", + "end-epoch-model", + ) + save_consolidated_model(trainer.model, hf_save_dir, rank) + if rank == 0: + tokenizer.save_pretrained(hf_save_dir) + + dataset.reset_dataloaders() + + +if __name__ == "__main__": + args = parse_args() + config = Config(yaml_path=args.yaml_path) + setup(config.train_parameters.output_dir) + main(config) + cleanup() diff --git a/launch_lora_benchmark.py b/launch_lora_benchmark.py new file mode 100644 index 0000000..e7c2f60 --- /dev/null +++ b/launch_lora_benchmark.py @@ -0,0 +1,58 @@ +""" +Create SLURM jobs running the LoRA benchmark. +""" + +from typing import List +import itertools +import subprocess +import time + +model_list = [ + "/model-weights/" + model_name + for model_name in [ + "opt-350m", + "Llama-2-7b-hf", + "Llama-2-13b-hf", + "Mistral-7B-v0.1", + "t5-xl-lm-adapt", + ] +] + +slurm_flags_options = { + "nodes": [1], + "mem": [0], + "ntasks-per-node": [1], + "cpus-per-gpu": [6], + "gres": ["gpu:{}".format(n + 1) for n in range(1)], + "partition": ["t4v2", "a40", "a100"], +} + +slurm_flags_extra = {"time": "00:30:00", "qos": "scavenger"} + +slurm_pos_args_options = [["examples/launch_lora_benchmark.sh"], model_list] +timestamp = int(time.time()) + +for index, (flag_values, pos_args_option) in enumerate( + zip( + itertools.product(*(slurm_flags_options.values())), + itertools.product(*slurm_pos_args_options), + ) +): + args: List[str] = ["sbatch"] + + extra_flags = { + **slurm_flags_extra, + "output": "data/output/{}.{}.out".format(timestamp, index), + "error": "data/output/{}.{}.out".format(timestamp, index), + "job-name": "bench-{}-{}".format(timestamp, index), + } + + keys = list(slurm_flags_options.keys()) + list(extra_flags.keys()) + values = list(flag_values) + list(extra_flags.values()) + for key, value in zip(keys, values): + arg = ("--{}".format(key), str(value)) + args.extend(arg) + + args.extend(pos_args_option) + + print(" ".join(args)) diff --git a/lora_benchmark.py b/lora_benchmark.py new file mode 100644 index 0000000..0791908 --- /dev/null +++ b/lora_benchmark.py @@ -0,0 +1,264 @@ +# Renamed from examples/llama_example.py +import argparse +import contextlib +import json +import math +import os +import sys +import time +from argparse import Namespace +from typing import Any, Dict, Optional + +import torch +import torch.distributed as dist +from torch.optim import AdamW +from tqdm import tqdm +from transformers import set_seed +from peft.utils.other import fsdp_auto_wrap_policy + +from vectorlm.dataset import Dataset +from vectorlm.trainer import Trainer +from vectorlm.utils.data_utils import Config +from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup +from vectorlm.utils.model_utils import ( + hook_activation_checkpointing, + initialize_lora_model_and_tokenizer, + shard_model, + get_submodule_by_pattern, +) +from vectorlm.utils.optimizer_utils import get_custom_scheduler +from vectorlm.utils.save_utils import save_consolidated_model + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns + ------- + The parsed arguments. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--yaml_path", + default="configs/config.yaml", + required=False, + ) + parser.add_argument( + "--model_name", + required=True, + ) + return parser.parse_args() + + +# unix timestamp +launch_time = time.time() +os.makedirs("data/benchmark", exist_ok=True) +output_path = "data/benchmark/{}.jsonl".format(launch_time) + + +def write_metrics(metric_name: str, value: Optional[Any] = None) -> None: + """ + Write metric and time elapsed to output file. + Write to disk only if process rank is 0. + + Params: + metric_name: string indicating type of metric + value: JSON-serializable value, + or None to log only time elapsed + """ + time_since_launch = time.time() - launch_time + output_dict = { + "name": metric_name, + "time_since_launch": time_since_launch, + "value": value, + } + output_line = json.dumps(output_dict) + + if dist.get_rank() == 0: + with open(output_path, "a") as output_file: + output_file.write(output_line + "\n") + + +@contextlib.contextmanager +def track_time(task_name: str, extra_info: Dict[str, Any] = {}): + start_time = time.time() + try: + yield + finally: + time_elapsed = time.time() - start_time + write_metrics(task_name, {"time_elapsed": time_elapsed, **extra_info}) + + +def get_device_info() -> Dict[str, str | int]: + """ + Get CUDA info as a dict. + + Returns: + Dict including device_name and world size + """ + return dict( + device_name=torch.cuda.get_device_name(), + local_rank=int(os.environ["LOCAL_RANK"]), + rank=int(os.environ["RANK"]), + world_size=int(os.environ["WORLD_SIZE"]), + ) + + +def get_is_flash_attention_supported() -> bool: + """ + Returns: + Whether Flash Attention is supported based on + the given CUDA device capability. + """ + version_major, _ = torch.cuda.get_device_capability() + return version_major >= 8 + + +def get_slurm_env() -> Dict[str, str]: + """ + Returns a dictionary of all env var starting with "SLURM_". + """ + output = { + key: value for key, value in os.environ.items() if key.startswith("SLURM_") + } + return output + + +def main(config: Config, model_name: str) -> None: + """Define the main calling function.""" + write_metrics("model_name", model_name) + write_metrics("config", {**config.__dict__}) + write_metrics("device_info", get_device_info()) + write_metrics("slurm_info", get_slurm_env()) + + training_args = config.train_parameters + + # set a seed + set_seed(training_args.seed) + + # set CUDA related dependencies + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + with track_time("dist_init"): + print(f"Rank: {rank}, World size: {world_size}") + if dist.is_initialized(): + torch.cuda.set_device(local_rank) + torch.cuda.empty_cache() + + # setup wandb + if rank == 0: + wandb_setup(config, **config.wandb_config) + dist.barrier() + + # load model and tokenizer + state_dict_path = getattr(config, "state_dict", None) + + with track_time("model_load"): + model, tokenizer = initialize_lora_model_and_tokenizer( + model_name, + training_args.use_mp, + get_is_flash_attention_supported(), + training_args.max_seq_len, + config.lora_peft_config, + ) + decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") + + if decoder_layer_module is None: + track_time("decoder_layer_module_is_none") + raise ValueError("decoder_layer_module is None.") + + with track_time("model_shard"): + model = shard_model( + model, + decoder_layer_module, + training_args.use_mp, + training_args.use_activation_checkpointing, + training_args.sharding_strategy, + ) + + with track_time("set_activation_checkpointing"): + if training_args.use_activation_checkpointing: + hook_activation_checkpointing(model, decoder_layer_module) + + # load dataset + with track_time("dataset_load"): + dataset = Dataset( + config=config.dataset, + tokenizer=tokenizer, + ) + + # instantiate trainer + trainer = Trainer( + config=training_args, + enable_wandb_logging=config.enable_wandb_logging, + original_dataset_length=dataset.original_length, + timer_handle=track_time, + ) + + # load optimizer + with track_time("optimizer_initialize"): + optimizer = AdamW( + model.parameters(), + **training_args.optimizer, + ) + + # load lr scheduler + lr_scheduler = get_custom_scheduler( + training_args.lr_scheduler_type, + optimizer, + math.ceil( + trainer.num_update_steps_per_epoch * training_args.warmup_ratio, + ), + trainer.max_steps, + ) + + trainer.prepare_trainer( + model, + tokenizer, + dataset, + optimizer, + lr_scheduler, + ) + + # TODO: support restoring LoRA fine-tuning + trainer.dataset.setup_dataloaders() + checkpointed_epoch = 0 + + for epoch in range(checkpointed_epoch, training_args.epochs): + trainer.model.train() + train_dl_iterator = iter(dataset.train_dataloader) + for _ in tqdm( + range(len(dataset.train_dataloader)), + disable=rank != 0, + file=sys.__stdout__, + ): + batch = next(train_dl_iterator) + trainer.step(batch, epoch) + + if epoch == training_args.epochs - 1: + with track_time("save_final"): + hf_save_dir = os.path.join(training_args.output_dir, "final-model") + else: + with track_time("save_checkpoint"): + hf_save_dir = os.path.join( + training_args.output_dir, + "checkpoints", + f"epoch_{epoch}", + "end-epoch-model", + ) + with track_time("save_consolidated"): + save_consolidated_model(trainer.model, hf_save_dir, rank) + if rank == 0: + tokenizer.save_pretrained(hf_save_dir) + + dataset.reset_dataloaders() + + +if __name__ == "__main__": + args = parse_args() + config = Config(yaml_path=args.yaml_path) + setup(config.train_parameters.output_dir) + main(config, args.model_name) + cleanup() diff --git a/vectorlm/tests/__init__.py b/vectorlm/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vectorlm/tests/test_modelling.py b/vectorlm/tests/test_modelling.py new file mode 100644 index 0000000..31e5c37 --- /dev/null +++ b/vectorlm/tests/test_modelling.py @@ -0,0 +1,256 @@ +""" +Test model loading, sharding, and forward/backward. +""" + +from collections import Counter, defaultdict + +import pytest +import torch +import torch.distributed as dist +from torch import nn +from torch.optim import AdamW +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, +) +from transformers.models.opt.modeling_opt import OPTDecoderLayer + +from vectorlm.utils.model_utils import ( + hook_activation_checkpointing, + initialize_lora_model_and_tokenizer, + load_model_and_tokenizer, + shard_model, + get_submodule_by_pattern, +) + + +@pytest.fixture(scope="session") +def setup_and_teardown_torch_process_group(): + # Setup + dist.init_process_group( + backend="nccl", + init_method="tcp://localhost:25567", + rank=0, + world_size=1, + ) + + yield + + # Teardown + dist.destroy_process_group() + + +@pytest.fixture() +def lora_peft_config(): + """ + Example peft config_dict for LoRA. + """ + return { + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": 8, + "lora_alpha": 32, + "lora_dropout": 0.1, + } + + +@pytest.fixture(scope="session") +def base_model(): + model, tokenizer = load_model_and_tokenizer("facebook/opt-125m", True, False, 1024) + return model + + +@pytest.fixture() +def lora_model(lora_peft_config): + lora_model, tokenizer = initialize_lora_model_and_tokenizer( + "facebook/opt-125m", True, False, 1024, lora_peft_config + ) + return lora_model + + +@pytest.fixture() +def base_model_sharded(base_model, setup_and_teardown_torch_process_group): + model_sharded = shard_model(base_model, OPTDecoderLayer, True, True, "FULL_SHARD") + return model_sharded + + +@pytest.fixture() +def lora_model_sharded(lora_model, setup_and_teardown_torch_process_group): + model_sharded = shard_model(lora_model, OPTDecoderLayer, True, True, "FULL_SHARD") + return FSDP(model_sharded, device_id=torch.cuda.current_device()) + + +@pytest.fixture() +def optimizer_lora_sharded(lora_model_sharded): + optimizer = AdamW(lora_model_sharded.parameters()) + return optimizer + + +@pytest.fixture() +def batch(): + batch = { + "input_ids": torch.zeros((1, 12)), + "labels": torch.zeros((1, 12)), + "attention_mask": torch.ones((1, 12)), + } + + batch = {k: v.type(torch.LongTensor) for k, v in batch.items()} + batch = {k: v.to(torch.device(0)) for k, v in batch.items()} + + return batch + + +def test_load_model_and_tokenizer(): + """ + Test load base model and tokenizer. + """ + model, tokenizer = load_model_and_tokenizer("facebook/opt-125m", True, True, 1024) + + print("type(model): {}".format(type(model))) + + +def test_load_lora_model_and_tokenizer(lora_peft_config): + """ + Test load base model and tokenizer. + """ + lora_model, tokenizer = initialize_lora_model_and_tokenizer( + "facebook/opt-125m", True, True, 1024, lora_peft_config + ) + + print("type(lora_model): {}".format(type(lora_model))) + + +def test_match_submodule_by_pattern(base_model): + """ + Test selecting DecoderLayer class from container. + """ + + submodule = get_submodule_by_pattern(base_model, r"DecoderLayer$") + assert submodule == OPTDecoderLayer + + +def test_partition_base_model(base_model, setup_and_teardown_torch_process_group): + """ + Test partitioning base model (no lora/peft). + """ + base_model = shard_model(base_model, OPTDecoderLayer, True, True, "FULL_SHARD") + + output_text = [] + for parameter_name, parameter in base_model.named_parameters(): + requires_grad = parameter.requires_grad + output_text.append("{}\t{}".format(requires_grad, parameter_name)) + + with open("output_base.txt", "w") as output_file: + output_file.write("\n".join(output_text)) + + +def test_get_module_types(lora_model_sharded): + """ + Output type of each module. + """ + output_text = [] + print(lora_model_sharded) + + for module_name, module in lora_model_sharded.named_modules(): + output_text.append("{}\t{}".format(module_name, type(module))) + + with open("module_types.txt", "w") as output_file: + output_file.write("\n".join(output_text)) + + +def test_partition_lora_model(lora_model, setup_and_teardown_torch_process_group): + """ + Test partitioning lora peft model. + """ + # # lora.Linear is a submodule of OPTDecoderLayer. + # for index, module in enumerate(lora_model.modules()): + # print(index, module) + + model_sharded = shard_model( + lora_model, nn.modules.linear.Linear, True, True, "FULL_SHARD" + ) + model_sharded = FSDP( + model_sharded, use_orig_params=True, device_id=torch.cuda.current_device() + ) + + requires_grad_counters = defaultdict(Counter) + + output_text = [] + reference_device = None + for parameter_name, parameter in model_sharded.named_parameters(): + requires_grad = parameter.requires_grad + requires_grad_counters[requires_grad][parameter_name] += 1 + output_text.append( + "{}\t{}\t{}".format(requires_grad, parameter.device, parameter_name) + ) + + if reference_device is not None: + assert parameter.device == reference_device + + reference_device = parameter.device + + with open("output.txt", "w") as output_file: + output_file.write("\n".join(output_text)) + + +def test_forward_base(base_model_sharded, batch): + """ + Test forward run of sharded base model. + """ + base_model_sharded.train() + output = base_model_sharded(**batch) + loss = output.loss + loss.backward() + print(output) + print(loss) + print(loss.shape) + + +def test_forward_lora(lora_model_sharded, batch): + """ + Test forward run of sharded lora model. + """ + lora_model_sharded.train() + output = lora_model_sharded(**batch) + loss = output.loss + print(output) + print(loss) + print(loss.shape) + + +def test_forward_backward_lora(lora_model_sharded, batch): + """ + Test forward and backward run of sharded lora model. + """ + lora_model_sharded.train() + output = lora_model_sharded(**batch) + loss = output.loss + + loss.backward() + + print(output) + print(loss) + print(loss.shape) + + +def test_train_lora(lora_model_sharded, optimizer_lora_sharded, batch): + """ + Test N optimization steps on the LoRA sharded model. + """ + optimizer = optimizer_lora_sharded + model = lora_model_sharded + loss_values = [] + for _ in range(7 * 13): + output = model(**batch) + loss = output.loss + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + loss_values.append(loss.cpu().item()) + print(loss.cpu().item()) + + print(loss_values) + assert loss_values[-1] < loss_values[0] diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 0325468..6fc8302 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import contextmanager import math import os from typing import Any @@ -27,6 +28,15 @@ ) +@contextmanager +def timer_placeholder(task_name: str): + try: + yield # start code block + finally: + # run before exiting + return + + class Trainer: """Main trainer class. @@ -49,11 +59,13 @@ class Trainer: saving_steps: An integer for how often we save. """ - def __init__(self, - config: Config, - enable_wandb_logging: bool, - original_dataset_length: int, - ) -> None: + def __init__( + self, + config: Config, + enable_wandb_logging: bool, + original_dataset_length: int, + timer_handle=timer_placeholder, + ) -> None: """Initialize the Trainer class. Args: @@ -62,6 +74,7 @@ def __init__(self, enable_wandb_logging: Whether to enable wandb logging. original_dataset_length: The length of the original dataset (divided by the batch size). + timer_handle: Optional context manager for profiling. """ self.config = config self.gas = config.gradient_accumulation_steps @@ -80,6 +93,7 @@ def __init__(self, self.num_update_steps_per_epoch = None self.max_steps = None self.saving_steps = None + self.timer_handle = timer_handle self._post_process(original_dataset_length) def _post_process(self, ds_orig_length: int) -> None: @@ -145,9 +159,16 @@ def save_checkpoint(self, epoch: int) -> None: ) if rank == 0: save_metadata(save_dir, meta_dict) - save_model(self.model, save_dir, rank) - save_optimizer(self.optimizer, self.model, save_dir, rank) - save_scheduler(self.lr_scheduler, save_dir, rank) + + with self.timer_handle("trainer_save_model"): + save_model(self.model, save_dir, rank) + + with self.timer_handle("trainer_save_optimizer"): + save_optimizer(self.optimizer, self.model, save_dir, rank) + + with self.timer_handle("train_save_scheduler"): + save_scheduler(self.lr_scheduler, save_dir, rank) + dist.barrier() def load_checkpoint(self, checkpoint_dir: str) -> int: @@ -218,7 +239,9 @@ def step( ): self.save_checkpoint(epoch) - train_loss = self.train_step(train_batch, epoch) + num_tokens = len(train_batch["input_ids"].flatten()) + with self.timer_handle("train_step", {"num_tokens": num_tokens}): + train_loss = self.train_step(train_batch, epoch) test_loss = None if self.tr_step % self.logging_steps == 0: @@ -297,10 +320,14 @@ def eval_step(self, epoch: int) -> float: with torch.no_grad(): batch.pop("id") batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) + num_tokens = len(batch["input_ids"].flatten()) batch["labels"] = batch["labels"].type(torch.LongTensor) batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} - out = self.model(**batch) - eval_loss += out.loss + + with self.timer_handle("eval_step", {"num_tokens": num_tokens}): + out = self.model(**batch) + eval_loss += out.loss + gathered_eval_loss = _gather(eval_loss.reshape(1)).mean().item() mean_eval_loss = gathered_eval_loss / len(self.dataset.eval_dataloader) diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 0f6d045..ca7a70a 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -1,11 +1,13 @@ from __future__ import annotations import functools +import re from typing import Any, Dict, Optional, Tuple import torch import torch.distributed as dist from peft import LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model +from peft.utils.other import fsdp_auto_wrap_policy from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, @@ -16,11 +18,6 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, ) -from torch.distributed.fsdp.wrap import ( - transformer_auto_wrap_policy, - _or_policy, - lambda_auto_wrap_policy, -) from transformers import ( AutoModelForCausalLM, @@ -148,17 +145,13 @@ def load_model_and_tokenizer( return model, tokenizer -def fsdp_config( - use_mp: bool, - layer_to_wrap: nn.Module, - strategy: str, -) -> dict[str, Any]: +def fsdp_config(use_mp: bool, model: nn.Module, strategy: str) -> dict[str, Any]: """Get FSDP config. Args: ---- use_mp: Whether to use mixed-precision. - layer_to_wrap: The layer we are wrapping using FSDP. + model_to_wrap: The HuggingFace model to wrap using FSDP. strategy: The sharding strategy to use. Returns: @@ -179,29 +172,9 @@ def fsdp_config( ) ret_dict["mixed_precision"] = mp_policy - # See https://github.com/facebookresearch/llama-recipes/blob/674b37ee6/src/llama_recipes/utils/fsdp_utils.py#L9 - def lambda_policy_fn(module): - if ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ): - return True - return False - - lambda_policy = functools.partial( - lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn - ) - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=[layer_to_wrap] - ) - - auto_wrap_policy = functools.partial( - _or_policy, policies=[lambda_policy, transformer_wrap_policy] - ) sharding_strategy = getattr(ShardingStrategy, strategy) - ret_dict["auto_wrap_policy"] = auto_wrap_policy + ret_dict["auto_wrap_policy"] = fsdp_auto_wrap_policy(model) ret_dict["sharding_strategy"] = sharding_strategy ret_dict["device_id"] = torch.cuda.current_device() return ret_dict @@ -228,7 +201,7 @@ def shard_model( ------- The sharded module with the requested configurations. """ - fsdp_cfg = fsdp_config(use_mp, layer_to_wrap, strategy) + fsdp_cfg = fsdp_config(use_mp, model, strategy) if dist.get_rank() == 0: print(f"FSDP config: {fsdp_cfg}") model = FSDP(model, **fsdp_cfg) @@ -265,3 +238,35 @@ def hook_activation_checkpointing( checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn, ) + + +def get_submodule_by_pattern( + module: nn.Module, pattern: str +) -> Optional[type[nn.Module]]: + """ + Return the first module.cls that matches pattern, + at least partially. + + With reference to get_module_class_from_name from HuggingFace + accelerate `FullyShardedDataParallelPlugin`. + + Args: + ----- + module: Layer container + pattern: regular expression string. + + Returns: + -------- + Matched layer (nn.Module) or None if not matched. + """ + modules_children = list(module.children()) + module_name = module.__class__.__name__ + if re.search(pattern, module_name) is not None: + return module.__class__ + elif len(modules_children) == 0: + return + else: + for child_module in modules_children: + module_class = get_submodule_by_pattern(child_module, pattern) + if module_class is not None: + return module_class From 906e4f3c7bcb88753cad795328fd24900245ddf8 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 13:24:18 -0400 Subject: [PATCH 13/89] model_utils: Refactored get_lora_model to reduce interface width. (this method no longer wraps load_model_and_tokenizer) test_modelling: revised base model fixture scope since torch FSDP wrap is in-place. launch_benchmark: added confirmation before launching. --- lora_benchmark.py => benchmark.py | 10 ++-- ...h_lora_benchmark.py => launch_benchmark.py | 15 ++++-- vectorlm/tests/test_modelling.py | 53 ++++++++----------- vectorlm/utils/model_utils.py | 28 +++++----- 4 files changed, 52 insertions(+), 54 deletions(-) rename lora_benchmark.py => benchmark.py (96%) rename launch_lora_benchmark.py => launch_benchmark.py (82%) diff --git a/lora_benchmark.py b/benchmark.py similarity index 96% rename from lora_benchmark.py rename to benchmark.py index 0791908..033a35d 100644 --- a/lora_benchmark.py +++ b/benchmark.py @@ -22,9 +22,10 @@ from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup from vectorlm.utils.model_utils import ( hook_activation_checkpointing, - initialize_lora_model_and_tokenizer, + load_model_and_tokenizer, shard_model, get_submodule_by_pattern, + get_lora_model_from_base_model, ) from vectorlm.utils.optimizer_utils import get_custom_scheduler from vectorlm.utils.save_utils import save_consolidated_model @@ -154,15 +155,18 @@ def main(config: Config, model_name: str) -> None: # load model and tokenizer state_dict_path = getattr(config, "state_dict", None) + lora_peft_config = getattr(config, "state_dict", None) with track_time("model_load"): - model, tokenizer = initialize_lora_model_and_tokenizer( + model, tokenizer = load_model_and_tokenizer( model_name, training_args.use_mp, get_is_flash_attention_supported(), training_args.max_seq_len, - config.lora_peft_config, ) + if lora_peft_config is not None: + model = get_lora_model_from_base_model(model, lora_peft_config) + decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") if decoder_layer_module is None: diff --git a/launch_lora_benchmark.py b/launch_benchmark.py similarity index 82% rename from launch_lora_benchmark.py rename to launch_benchmark.py index e7c2f60..cf19041 100644 --- a/launch_lora_benchmark.py +++ b/launch_benchmark.py @@ -7,6 +7,8 @@ import subprocess import time +from tqdm.auto import tqdm + model_list = [ "/model-weights/" + model_name for model_name in [ @@ -14,7 +16,6 @@ "Llama-2-7b-hf", "Llama-2-13b-hf", "Mistral-7B-v0.1", - "t5-xl-lm-adapt", ] ] @@ -23,7 +24,7 @@ "mem": [0], "ntasks-per-node": [1], "cpus-per-gpu": [6], - "gres": ["gpu:{}".format(n + 1) for n in range(1)], + "gres": ["gpu:{}".format(n + 1) for n in range(8)], "partition": ["t4v2", "a40", "a100"], } @@ -32,8 +33,9 @@ slurm_pos_args_options = [["examples/launch_lora_benchmark.sh"], model_list] timestamp = int(time.time()) +args_list: List[List[str]] = [] for index, (flag_values, pos_args_option) in enumerate( - zip( + itertools.product( itertools.product(*(slurm_flags_options.values())), itertools.product(*slurm_pos_args_options), ) @@ -54,5 +56,10 @@ args.extend(arg) args.extend(pos_args_option) - + args_list.append(args) print(" ".join(args)) + +input("\nPress ENTER to launch {} job(s)".format(len(args_list))) + +for args in tqdm(args_list): + subprocess.run(args) diff --git a/vectorlm/tests/test_modelling.py b/vectorlm/tests/test_modelling.py index 31e5c37..d9290b0 100644 --- a/vectorlm/tests/test_modelling.py +++ b/vectorlm/tests/test_modelling.py @@ -3,6 +3,7 @@ """ from collections import Counter, defaultdict +import re import pytest import torch @@ -17,15 +18,14 @@ from transformers.models.opt.modeling_opt import OPTDecoderLayer from vectorlm.utils.model_utils import ( - hook_activation_checkpointing, - initialize_lora_model_and_tokenizer, + get_lora_model_from_base_model, load_model_and_tokenizer, shard_model, get_submodule_by_pattern, ) -@pytest.fixture(scope="session") +@pytest.fixture() def setup_and_teardown_torch_process_group(): # Setup dist.init_process_group( @@ -55,17 +55,15 @@ def lora_peft_config(): } -@pytest.fixture(scope="session") +@pytest.fixture() def base_model(): model, tokenizer = load_model_and_tokenizer("facebook/opt-125m", True, False, 1024) return model @pytest.fixture() -def lora_model(lora_peft_config): - lora_model, tokenizer = initialize_lora_model_and_tokenizer( - "facebook/opt-125m", True, False, 1024, lora_peft_config - ) +def lora_model(base_model, lora_peft_config): + lora_model = get_lora_model_from_base_model(base_model, lora_peft_config) return lora_model @@ -110,18 +108,7 @@ def test_load_model_and_tokenizer(): print("type(model): {}".format(type(model))) -def test_load_lora_model_and_tokenizer(lora_peft_config): - """ - Test load base model and tokenizer. - """ - lora_model, tokenizer = initialize_lora_model_and_tokenizer( - "facebook/opt-125m", True, True, 1024, lora_peft_config - ) - - print("type(lora_model): {}".format(type(lora_model))) - - -def test_match_submodule_by_pattern(base_model): +def test_match_submodule_by_pattern(base_model, lora_model): """ Test selecting DecoderLayer class from container. """ @@ -129,6 +116,9 @@ def test_match_submodule_by_pattern(base_model): submodule = get_submodule_by_pattern(base_model, r"DecoderLayer$") assert submodule == OPTDecoderLayer + submodule = get_submodule_by_pattern(base_model, r"DecoderLayer$") + assert submodule == OPTDecoderLayer + def test_partition_base_model(base_model, setup_and_teardown_torch_process_group): """ @@ -159,28 +149,25 @@ def test_get_module_types(lora_model_sharded): output_file.write("\n".join(output_text)) -def test_partition_lora_model(lora_model, setup_and_teardown_torch_process_group): +def test_fsdp_lora_model_require_grad( + lora_model_sharded, setup_and_teardown_torch_process_group +): """ Test partitioning lora peft model. """ - # # lora.Linear is a submodule of OPTDecoderLayer. - # for index, module in enumerate(lora_model.modules()): - # print(index, module) - - model_sharded = shard_model( - lora_model, nn.modules.linear.Linear, True, True, "FULL_SHARD" - ) - model_sharded = FSDP( - model_sharded, use_orig_params=True, device_id=torch.cuda.current_device() - ) requires_grad_counters = defaultdict(Counter) output_text = [] reference_device = None - for parameter_name, parameter in model_sharded.named_parameters(): + for parameter_name, parameter in lora_model_sharded.named_parameters(): requires_grad = parameter.requires_grad requires_grad_counters[requires_grad][parameter_name] += 1 + if re.search("lora_[A|B]", parameter_name) is not None: + assert requires_grad == True, parameter_name + else: + assert requires_grad == False, parameter_name + output_text.append( "{}\t{}\t{}".format(requires_grad, parameter.device, parameter_name) ) @@ -190,6 +177,8 @@ def test_partition_lora_model(lora_model, setup_and_teardown_torch_process_group reference_device = parameter.device + # # Uncomment line below to see all parameter names. + # print(requires_grad_counters) with open("output.txt", "w") as output_file: output_file.write("\n".join(output_text)) diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index ca7a70a..de36839 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -27,25 +27,23 @@ ) -def initialize_lora_model_and_tokenizer( - path: str, - use_mp: bool, - use_fa: bool, - max_seq_len: int, - peft_config_dict: Dict[str, Any], -) -> tuple[PeftModel, PreTrainedTokenizer]: +def get_lora_model_from_base_model( + base_model: nn.Module, peft_config_dict: Dict +) -> PeftModel: """ - Initialize lora peft configuration for a non-lora model. - """ - model, tokenizer = load_model_and_tokenizer(path, use_mp, use_fa, max_seq_len) + Initialize lora peft configuration from a non-lora model. - # Replace task type string in config with TaskType member. + Args: + ----- + base_model: HuggingFace Transformer model to wrap. + peft_config_dict: configuration from yaml config file. + """ task_type_str = peft_config_dict["task_type"] task_type = getattr(TaskType, task_type_str) lora_config = LoraConfig(**{**peft_config_dict, "task_type": task_type}) - lora_model = get_peft_model(model, lora_config) - return lora_model, tokenizer + lora_model = get_peft_model(base_model, lora_config) + return lora_model def load_peft_model_and_tokenizer( @@ -182,7 +180,7 @@ def fsdp_config(use_mp: bool, model: nn.Module, strategy: str) -> dict[str, Any] def shard_model( model: nn.Module, - layer_to_wrap: nn.Module, + layer_to_wrap: type[nn.Module], use_mp: bool, use_activation_checkpointing: bool, strategy: str, @@ -217,7 +215,7 @@ def shard_model( def hook_activation_checkpointing( model: nn.Module, - layer: nn.Module, + layer: type[nn.Module], ) -> None: """Set activation checkpointing. From 0c415359740056fb6b99fce60c0fae08d0eb21c3 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 13:27:40 -0400 Subject: [PATCH 14/89] test_modelling: moved text output to data/. --- .gitignore | 4 ++++ vectorlm/tests/test_modelling.py | 7 +++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index f79c4ac..05c42cc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,7 @@ __pycache__/ wandb/ build/ data/ +**/*.pyc +/.cache +/.vscode +/data \ No newline at end of file diff --git a/vectorlm/tests/test_modelling.py b/vectorlm/tests/test_modelling.py index d9290b0..f3a7af4 100644 --- a/vectorlm/tests/test_modelling.py +++ b/vectorlm/tests/test_modelling.py @@ -128,10 +128,9 @@ def test_partition_base_model(base_model, setup_and_teardown_torch_process_group output_text = [] for parameter_name, parameter in base_model.named_parameters(): - requires_grad = parameter.requires_grad output_text.append("{}\t{}".format(requires_grad, parameter_name)) - with open("output_base.txt", "w") as output_file: + with open("data/output_base.txt", "w") as output_file: output_file.write("\n".join(output_text)) @@ -145,7 +144,7 @@ def test_get_module_types(lora_model_sharded): for module_name, module in lora_model_sharded.named_modules(): output_text.append("{}\t{}".format(module_name, type(module))) - with open("module_types.txt", "w") as output_file: + with open("data/module_types.txt", "w") as output_file: output_file.write("\n".join(output_text)) @@ -179,7 +178,7 @@ def test_fsdp_lora_model_require_grad( # # Uncomment line below to see all parameter names. # print(requires_grad_counters) - with open("output.txt", "w") as output_file: + with open("data/output.txt", "w") as output_file: output_file.write("\n".join(output_text)) From f24d2fa8ce3d427c8d029ca4cfa4f0d90f3e33d3 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 13:28:25 -0400 Subject: [PATCH 15/89] added example yaml config for lora benchmarking. --- configs/config-lora-benchmark.yaml | 66 ++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 configs/config-lora-benchmark.yaml diff --git a/configs/config-lora-benchmark.yaml b/configs/config-lora-benchmark.yaml new file mode 100644 index 0000000..02812e9 --- /dev/null +++ b/configs/config-lora-benchmark.yaml @@ -0,0 +1,66 @@ +enable_wandb_logging: True + +lora_peft_config: + task_type: CAUSAL_LM + inference_mode: False + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + +wandb_config: + project: vector-lm-verify + name: benchmark-lora + +train_parameters: + output_dir: /tmp/lora-benchmark + max_seq_len: 128 + epochs: 1 + seed: 11 + + # Sharding strategy + sharding_strategy: FULL_SHARD + + # Memory + use_mp: True + use_activation_checkpointing: True + # use_flash_attention is automatically enabled + # for CUDA capability > 8.0 + + # Gradient norm clipping + max_grad_norm: 1 + gradient_accumulation_steps: 4 + + # Optimizer + optimizer: + lr: 2.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.95] + eps: 1.0e-5 + + # Scheduler + lr_scheduler_type: cosine + warmup_ratio: 0.05 + + # Checkpointing + checkpointing_enabled: True + logging_steps: 500 + save_frequency: 0.25 + +dataset: + ignore_index: -100 + eval_bs: 8 + train_bs: 8 + train_ds: data/processed/gsm8k-question/train + eval_ds: data/processed/gsm8k-question/test + +dataset_preprocess: + ignore_index: -100 + dataset_format: hf + data_field: question + packing_type: partial + add_bos_eos_tokens: True + from_disk: True + load_path: data/raw/gsm8k + split: train + save_path: data/processed/gsm8k-question/train + truncate: False From 7d27d90307f1536eb9968fea29570cb4a8fe8d0d Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 13:47:03 -0400 Subject: [PATCH 16/89] launch_benchmark: marked qos flag as optional. --- launch_benchmark.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/launch_benchmark.py b/launch_benchmark.py index cf19041..b9b9688 100644 --- a/launch_benchmark.py +++ b/launch_benchmark.py @@ -2,13 +2,19 @@ Create SLURM jobs running the LoRA benchmark. """ -from typing import List +import argparse import itertools import subprocess import time +from typing import List from tqdm.auto import tqdm +parser = argparse.ArgumentParser() +parser.add_argument("--qos", required=False) +cli_args = parser.parse_args() +qos_selected = cli_args.qos + model_list = [ "/model-weights/" + model_name for model_name in [ @@ -28,7 +34,7 @@ "partition": ["t4v2", "a40", "a100"], } -slurm_flags_extra = {"time": "00:30:00", "qos": "scavenger"} +slurm_flags_extra = {"time": "00:30:00", "qos": qos_selected} slurm_pos_args_options = [["examples/launch_lora_benchmark.sh"], model_list] timestamp = int(time.time()) @@ -52,8 +58,9 @@ keys = list(slurm_flags_options.keys()) + list(extra_flags.keys()) values = list(flag_values) + list(extra_flags.values()) for key, value in zip(keys, values): - arg = ("--{}".format(key), str(value)) - args.extend(arg) + if value is not None: + arg = ("--{}".format(key), str(value)) + args.extend(arg) args.extend(pos_args_option) args_list.append(args) From d22ea852dba2b65993ad27a5e9f66eb2dde1e3bc Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 13:57:52 -0400 Subject: [PATCH 17/89] launch_benchmark: added option to limit number of jobs launched. --- launch_benchmark.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/launch_benchmark.py b/launch_benchmark.py index b9b9688..f54e758 100644 --- a/launch_benchmark.py +++ b/launch_benchmark.py @@ -12,8 +12,10 @@ parser = argparse.ArgumentParser() parser.add_argument("--qos", required=False) -cli_args = parser.parse_args() -qos_selected = cli_args.qos +parser.add_argument("--max_num_jobs", required=False) +launcher_args = parser.parse_args() +qos_selected = launcher_args.qos +max_num_jobs = launcher_args.max_num_jobs model_list = [ "/model-weights/" + model_name @@ -66,7 +68,10 @@ args_list.append(args) print(" ".join(args)) + if (max_num_jobs is not None) and index + 1 >= int(max_num_jobs): + break + input("\nPress ENTER to launch {} job(s)".format(len(args_list))) -for args in tqdm(args_list): +for args in tqdm(args_list, ncols=75): subprocess.run(args) From 84b953a27ded5e2cd79e10d876cbc8fc9633ea2e Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 14:09:35 -0400 Subject: [PATCH 18/89] launch_benchmark: implemented torch profiler integration. --- benchmark.py | 112 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 84 insertions(+), 28 deletions(-) diff --git a/benchmark.py b/benchmark.py index 033a35d..8c50533 100644 --- a/benchmark.py +++ b/benchmark.py @@ -12,6 +12,7 @@ import torch import torch.distributed as dist from torch.optim import AdamW +from torch.profiler import ProfilerActivity from tqdm import tqdm from transformers import set_seed from peft.utils.other import fsdp_auto_wrap_policy @@ -125,6 +126,49 @@ def get_slurm_env() -> Dict[str, str]: return output +def parse_profiler_output( + profiler_output: torch.autograd.profiler.profile, +) -> Dict[str, Dict[str, str | float | int]]: + """ + Parse profiler_output to obtain dictionary of metrics. + + Returns: + Dictionary mapping event name to dictionary of metrics. + """ + key_average_event_list = profiler_output.key_averages() + output: Dict[str, Dict[str, str | float | int]] = {} + for evt in key_average_event_list: + if evt.trace_name is None: + continue + output[evt.trace_name] = { + "start": evt.time_range.start, + "elapsed": evt.time_range.elapsed_us(), + "args": ( + evt.thread + if not evt.is_remote + else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "' + ), + } + + return output + + +def handle_profiler_trace(profiler_output: torch.autograd.profiler.profile): + """ + Log torch profile to disk. + This function is to be invoked as a callback for on_track_ready. + + Args: + ----- + profile: from Torch profiler. + """ + print(profiler_output) + key_average_event_list = profiler_output.key_averages() + write_metrics("profiler_table", key_average_event_list.table()) + parsed_output = parse_profiler_output(profiler_output) + write_metrics("profiler_output", parsed_output) + + def main(config: Config, model_name: str) -> None: """Define the main calling function.""" write_metrics("model_name", model_name) @@ -132,6 +176,10 @@ def main(config: Config, model_name: str) -> None: write_metrics("device_info", get_device_info()) write_metrics("slurm_info", get_slurm_env()) + profiler_schedule = torch.profiler.schedule( + skip_first=10, wait=5, warmup=1, active=3, repeat=2 + ) + training_args = config.train_parameters # set a seed @@ -230,34 +278,42 @@ def main(config: Config, model_name: str) -> None: trainer.dataset.setup_dataloaders() checkpointed_epoch = 0 - for epoch in range(checkpointed_epoch, training_args.epochs): - trainer.model.train() - train_dl_iterator = iter(dataset.train_dataloader) - for _ in tqdm( - range(len(dataset.train_dataloader)), - disable=rank != 0, - file=sys.__stdout__, - ): - batch = next(train_dl_iterator) - trainer.step(batch, epoch) - - if epoch == training_args.epochs - 1: - with track_time("save_final"): - hf_save_dir = os.path.join(training_args.output_dir, "final-model") - else: - with track_time("save_checkpoint"): - hf_save_dir = os.path.join( - training_args.output_dir, - "checkpoints", - f"epoch_{epoch}", - "end-epoch-model", - ) - with track_time("save_consolidated"): - save_consolidated_model(trainer.model, hf_save_dir, rank) - if rank == 0: - tokenizer.save_pretrained(hf_save_dir) - - dataset.reset_dataloaders() + # See pytorch.org/tutorials/recipes/recipes/profiler_recipe.html + with torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=profiler_schedule, + on_trace_ready=handle_profiler_trace, + ) as profile_handle: + for epoch in range(checkpointed_epoch, training_args.epochs): + trainer.model.train() + train_dl_iterator = iter(dataset.train_dataloader) + for _ in tqdm( + # range(len(dataset.train_dataloader)), + range(7 * 13), + disable=rank != 0, + file=sys.__stdout__, + ): + batch = next(train_dl_iterator) + trainer.step(batch, epoch) + profile_handle.step() + + if epoch == training_args.epochs - 1: + with track_time("save_final"): + hf_save_dir = os.path.join(training_args.output_dir, "final-model") + else: + with track_time("save_checkpoint"): + hf_save_dir = os.path.join( + training_args.output_dir, + "checkpoints", + f"epoch_{epoch}", + "end-epoch-model", + ) + with track_time("save_consolidated"): + save_consolidated_model(trainer.model, hf_save_dir, rank) + if rank == 0: + tokenizer.save_pretrained(hf_save_dir) + + dataset.reset_dataloaders() if __name__ == "__main__": From e1cda073b9879da6bf4236f1c66e97a40b2f1110 Mon Sep 17 00:00:00 2001 From: Adil <47084919+adil-a@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:03:06 -0400 Subject: [PATCH 19/89] Merged changes from low CPU memory usage feature (#6) into jjt/lora-benchmarking * added changes to implement low cpu mem usage feature * implemented new ruff linting changes and ran a fix across files --- configs/config.yaml | 5 +- docs/config.md | 1 + examples/llama_example.py | 5 ++ benchmark.py => profiling/benchmark.py | 27 +++++++- .../configs/lora-benchmark.yaml | 6 +- .../launch_benchmark.py | 0 pyproject.toml | 9 +-- vectorlm/dataset.py | 2 + vectorlm/tests/test_dataset.py | 39 +++++++++++ vectorlm/tests/test_modelling.py | 34 ++++++---- vectorlm/trainer.py | 11 +++- vectorlm/utils/convert_to_hf.py | 2 + vectorlm/utils/data_utils.py | 5 ++ vectorlm/utils/misc_utils.py | 1 + vectorlm/utils/model_utils.py | 65 ++++++++++++++++--- vectorlm/utils/optimizer_utils.py | 7 +- vectorlm/utils/save_utils.py | 11 ++++ 17 files changed, 196 insertions(+), 34 deletions(-) rename benchmark.py => profiling/benchmark.py (91%) rename configs/config-lora-benchmark.yaml => profiling/configs/lora-benchmark.yaml (92%) rename launch_benchmark.py => profiling/launch_benchmark.py (100%) create mode 100644 vectorlm/tests/test_dataset.py diff --git a/configs/config.yaml b/configs/config.yaml index 2778f9c..4ce2dde 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,4 +1,4 @@ -model: /model-weights/Llama-2-7b-chat-hf +model: /model-weights/Llama-2-7b-chat-hf/ enable_wandb_logging: True wandb_config: @@ -7,7 +7,7 @@ wandb_config: train_parameters: output_dir: your/output/dir - max_seq_len: 1024 + max_seq_len: 4096 epochs: 1 seed: 11 @@ -18,6 +18,7 @@ train_parameters: use_mp: True use_activation_checkpointing: True use_flash_attention: True + low_cpu_mem_usage: True # Gradient norm clipping max_grad_norm: 1 diff --git a/docs/config.md b/docs/config.md index db56806..0541c10 100644 --- a/docs/config.md +++ b/docs/config.md @@ -28,6 +28,7 @@ The key-value pairs stored under `wandb_config` are directly passed into the [`w * `use_mp`: Whether to use mixed precision. This is done using bf16. * `use_activation_checkpointing`: Whether to use activation checkpointing. This greatly reduces memory footprint as only a few intermediate activations as saved during the forward pass, and are then recomputed for the backward pass on the fly. However, the tradeoff between compute vs. memory usually makes this worth it. * `use_flash_attention`: Whether to use Flash Attention. If it is supported for your model in HuggingFace, you can enable this option. +* `low_cpu_mem_usage`: Whether to efficiently load the model. If enabled, the model weights are only loaded once on rank 0 and are broadcasted to the rest of the world from the main rank. It will prevent the CPU memory from exploding when loading large models (e.g. LLaMa-70B). ### Gradient diff --git a/examples/llama_example.py b/examples/llama_example.py index 9e9f4eb..d794740 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -26,6 +26,7 @@ def parse_args() -> Namespace: Returns ------- The parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument( @@ -62,6 +63,8 @@ def main(config: Config) -> None: training_args.use_mp, training_args.use_flash_attention, training_args.max_seq_len, + local_rank, + training_args.low_cpu_mem_usage, ) model = shard_model( @@ -70,6 +73,8 @@ def main(config: Config) -> None: training_args.use_mp, training_args.use_activation_checkpointing, training_args.sharding_strategy, + local_rank, + training_args.low_cpu_mem_usage, ) # load dataset diff --git a/benchmark.py b/profiling/benchmark.py similarity index 91% rename from benchmark.py rename to profiling/benchmark.py index 8c50533..fb2ffe8 100644 --- a/benchmark.py +++ b/profiling/benchmark.py @@ -56,6 +56,7 @@ def parse_args() -> Namespace: launch_time = time.time() os.makedirs("data/benchmark", exist_ok=True) output_path = "data/benchmark/{}.jsonl".format(launch_time) +profiler_output_path = "data/trace/{}.json".format(launch_time) def write_metrics(metric_name: str, value: Optional[Any] = None) -> None: @@ -138,7 +139,8 @@ def parse_profiler_output( key_average_event_list = profiler_output.key_averages() output: Dict[str, Dict[str, str | float | int]] = {} for evt in key_average_event_list: - if evt.trace_name is None: + trace_name = getattr(evt, "trace_name", None) + if trace_name is None: continue output[evt.trace_name] = { "start": evt.time_range.start, @@ -167,6 +169,23 @@ def handle_profiler_trace(profiler_output: torch.autograd.profiler.profile): write_metrics("profiler_table", key_average_event_list.table()) parsed_output = parse_profiler_output(profiler_output) write_metrics("profiler_output", parsed_output) + profiler_output.export_chrome_trace(profiler_output_path) + + +class BenchmarkingDataset(Dataset): + def load_datasets(self) -> None: + """Load datasets into memory.""" + self.train_ds = [ + { + "id": row_id, + "input_ids": torch.zeros(1024), + "labels": torch.zeros(1024), + "attention_mask": torch.ones(1024), + } + for row_id in range(1024) + ] + self.eval_ds = self.train_ds + self.original_length = math.ceil(len(self.train_ds) / self.train_bs) def main(config: Config, model_name: str) -> None: @@ -211,6 +230,8 @@ def main(config: Config, model_name: str) -> None: training_args.use_mp, get_is_flash_attention_supported(), training_args.max_seq_len, + local_rank, + training_args.low_cpu_mem_usage, ) if lora_peft_config is not None: model = get_lora_model_from_base_model(model, lora_peft_config) @@ -228,6 +249,8 @@ def main(config: Config, model_name: str) -> None: training_args.use_mp, training_args.use_activation_checkpointing, training_args.sharding_strategy, + local_rank, + training_args.low_cpu_mem_usage, ) with track_time("set_activation_checkpointing"): @@ -236,7 +259,7 @@ def main(config: Config, model_name: str) -> None: # load dataset with track_time("dataset_load"): - dataset = Dataset( + dataset = BenchmarkingDataset( config=config.dataset, tokenizer=tokenizer, ) diff --git a/configs/config-lora-benchmark.yaml b/profiling/configs/lora-benchmark.yaml similarity index 92% rename from configs/config-lora-benchmark.yaml rename to profiling/configs/lora-benchmark.yaml index 02812e9..dbd6dfe 100644 --- a/configs/config-lora-benchmark.yaml +++ b/profiling/configs/lora-benchmark.yaml @@ -25,6 +25,8 @@ train_parameters: use_activation_checkpointing: True # use_flash_attention is automatically enabled # for CUDA capability > 8.0 + low_cpu_mem_usage: True + # Gradient norm clipping max_grad_norm: 1 @@ -50,8 +52,8 @@ dataset: ignore_index: -100 eval_bs: 8 train_bs: 8 - train_ds: data/processed/gsm8k-question/train - eval_ds: data/processed/gsm8k-question/test + train_ds: /dev/null + eval_ds: /dev/null dataset_preprocess: ignore_index: -100 diff --git a/launch_benchmark.py b/profiling/launch_benchmark.py similarity index 100% rename from launch_benchmark.py rename to profiling/launch_benchmark.py diff --git a/pyproject.toml b/pyproject.toml index 3cf671e..c9ca70d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [tool.ruff] line-length = 80 -select = ["ALL"] -ignore = [ +lint.select = ["ALL"] +lint.ignore = [ "ANN101", "FBT", "D100", @@ -15,9 +15,10 @@ ignore = [ "N817", "TCH001", "E731", - "PLR0913" + "PLR0913", + "T201" ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Ignore `F401` (import violations) in all `__init__.py` files. "__init__.py" = ["F401", "D104"] diff --git a/vectorlm/dataset.py b/vectorlm/dataset.py index 768ef8f..b383803 100644 --- a/vectorlm/dataset.py +++ b/vectorlm/dataset.py @@ -28,6 +28,7 @@ class Dataset: train_bs: A per-device batch size for training. eval_bs: A per-device batch size for evaluating. _processed_ids: A tensor of already trained examples. + """ def __init__( @@ -41,6 +42,7 @@ def __init__( ---- config: The dataset config. tokenizer: The input tokenizer. + """ self.config = config self._processed_ids = torch.tensor([]).to(torch.cuda.current_device()) diff --git a/vectorlm/tests/test_dataset.py b/vectorlm/tests/test_dataset.py new file mode 100644 index 0000000..7856487 --- /dev/null +++ b/vectorlm/tests/test_dataset.py @@ -0,0 +1,39 @@ +import pytest +from vectorlm.tests.test_modelling import setup_and_teardown_torch_process_group +from profiling.benchmark import BenchmarkingDataset + +from box import Box + +from transformers import AutoTokenizer + + +dataset_config = Box( + { + "ignore_index": -100, + "eval_bs": 8, + "train_bs": 8, + "train_ds": "/dev/null", + "eval_ds": "/dev/null", + } +) + + +@pytest.fixture() +def benchmark_dataset(setup_and_teardown_torch_process_group): + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + return BenchmarkingDataset(dataset_config, tokenizer) # type: ignore + + +def test_initialize_dataset(benchmark_dataset): + print(benchmark_dataset) + + +def test_get_batch(benchmark_dataset): + benchmark_dataset.setup_dataloaders() + dataset_iterator = iter(benchmark_dataset.train_dataloader) + batch = next(dataset_iterator) + + for key in ["input_ids", "attention_mask"]: + assert len(batch[key].shape) == 2 # batch, tokens + + print(batch) diff --git a/vectorlm/tests/test_modelling.py b/vectorlm/tests/test_modelling.py index f3a7af4..c0d8403 100644 --- a/vectorlm/tests/test_modelling.py +++ b/vectorlm/tests/test_modelling.py @@ -3,6 +3,7 @@ """ from collections import Counter, defaultdict +import os import re import pytest @@ -24,6 +25,8 @@ get_submodule_by_pattern, ) +local_rank = int(os.environ.get("LOCAL_RANK", 0)) + @pytest.fixture() def setup_and_teardown_torch_process_group(): @@ -57,7 +60,9 @@ def lora_peft_config(): @pytest.fixture() def base_model(): - model, tokenizer = load_model_and_tokenizer("facebook/opt-125m", True, False, 1024) + model, tokenizer = load_model_and_tokenizer( + "/model-weights/opt-350m", True, False, 1024, local_rank, True + ) return model @@ -69,13 +74,17 @@ def lora_model(base_model, lora_peft_config): @pytest.fixture() def base_model_sharded(base_model, setup_and_teardown_torch_process_group): - model_sharded = shard_model(base_model, OPTDecoderLayer, True, True, "FULL_SHARD") + model_sharded = shard_model( + base_model, OPTDecoderLayer, True, True, "FULL_SHARD", local_rank, True + ) return model_sharded @pytest.fixture() def lora_model_sharded(lora_model, setup_and_teardown_torch_process_group): - model_sharded = shard_model(lora_model, OPTDecoderLayer, True, True, "FULL_SHARD") + model_sharded = shard_model( + lora_model, OPTDecoderLayer, True, True, "FULL_SHARD", local_rank, True + ) return FSDP(model_sharded, device_id=torch.cuda.current_device()) @@ -99,13 +108,8 @@ def batch(): return batch -def test_load_model_and_tokenizer(): - """ - Test load base model and tokenizer. - """ - model, tokenizer = load_model_and_tokenizer("facebook/opt-125m", True, True, 1024) - - print("type(model): {}".format(type(model))) +def test_load_base_model(base_model): + print(base_model) def test_match_submodule_by_pattern(base_model, lora_model): @@ -120,14 +124,16 @@ def test_match_submodule_by_pattern(base_model, lora_model): assert submodule == OPTDecoderLayer -def test_partition_base_model(base_model, setup_and_teardown_torch_process_group): +def test_partition_base_model( + base_model_sharded, setup_and_teardown_torch_process_group +): """ Test partitioning base model (no lora/peft). """ - base_model = shard_model(base_model, OPTDecoderLayer, True, True, "FULL_SHARD") - output_text = [] - for parameter_name, parameter in base_model.named_parameters(): + for parameter_name, parameter in base_model_sharded.named_parameters(): + requires_grad = parameter.requires_grad + assert requires_grad == True output_text.append("{}\t{}".format(requires_grad, parameter_name)) with open("data/output_base.txt", "w") as output_file: diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 6fc8302..8fa61d5 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -7,11 +7,11 @@ import torch import torch.distributed as dist +import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from transformers import PreTrainedTokenizer -import wandb from vectorlm.dataset import Dataset from vectorlm.utils.data_utils import Config from vectorlm.utils.save_utils import ( @@ -57,6 +57,7 @@ class Trainer: epoch. max_steps: An integer maximum number of training steps for this run. saving_steps: An integer for how often we save. + """ def __init__( @@ -128,6 +129,7 @@ def prepare_trainer( dataset: The `Dataset` class. optimizer: The training optimizer. lr_scheduler: The LR scheduler. + """ self.model = model self.tokenizer = tokenizer @@ -141,6 +143,7 @@ def save_checkpoint(self, epoch: int) -> None: Args: ---- epoch: The current training epoch. + """ rank = dist.get_rank() gathered_processed_ids = _gather( @@ -182,6 +185,7 @@ def load_checkpoint(self, checkpoint_dir: str) -> int: Returns: ------- The checkpointed epoch to be used by the outer loop. + """ rank = dist.get_rank() step, epoch, ids = load_metadata(checkpoint_dir) @@ -205,6 +209,7 @@ def find_checkpoint(self, checkpoint_dir: str) -> int: ------- The checkpointed epoch. If no checkpoint exists, it returns a default value of 0. + """ checkpoint = checkpoint_exists(checkpoint_dir) if checkpoint: @@ -231,6 +236,7 @@ def step( ---- train_batch: The training batch. epoch: The current training epoch. + """ if ( self.config.checkpointing_enabled @@ -257,6 +263,7 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: ---- batch: The training batch. epoch: The current training epoch. + """ ids = batch.pop("id").to(torch.cuda.current_device()) batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) @@ -312,6 +319,7 @@ def eval_step(self, epoch: int) -> float: Args: ---- epoch: The current training epoch. + """ print_main("Evaluating") self.model.eval() @@ -348,6 +356,7 @@ def log(self, loss: float, epoch: int, mode: str = "train") -> None: loss: The loss being logged. epoch: The current training epoch. mode: One of `train` or `eval`. + """ if mode not in {"train", "eval"}: msg = "`mode` argument needs to be 'train' or 'eval'." diff --git a/vectorlm/utils/convert_to_hf.py b/vectorlm/utils/convert_to_hf.py index 3e66dce..c4f2405 100644 --- a/vectorlm/utils/convert_to_hf.py +++ b/vectorlm/utils/convert_to_hf.py @@ -14,6 +14,7 @@ def parse_args() -> Namespace: Returns ------- The parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument("--config_path", default="configs/config.yaml") @@ -28,6 +29,7 @@ def converter(config: Config) -> None: Args: ---- config: The full config. + """ state_dict = torch.load( os.path.join( diff --git a/vectorlm/utils/data_utils.py b/vectorlm/utils/data_utils.py index 4e14dd4..ddf021f 100644 --- a/vectorlm/utils/data_utils.py +++ b/vectorlm/utils/data_utils.py @@ -12,6 +12,7 @@ class Config: ---------- yaml_path: A path to the yaml file that stores our config. to_box: A boolean indicating whether to box our config. + """ def __init__(self, yaml_path: str, to_box: bool = True) -> None: @@ -21,6 +22,7 @@ def __init__(self, yaml_path: str, to_box: bool = True) -> None: ---- yaml_path: The string path to the config yaml. to_box: Defines whether this initialization will use dot notation. + """ self.yaml_path = yaml_path self.to_box = to_box @@ -55,6 +57,7 @@ class DataCollatorWithPadding: ignore_index: A value used for ignoring a given token in labels. max_seq_len: An integer denoting the maximum sequence length. padding_side: A side of the sequence that gets padded. + """ def __init__( @@ -73,6 +76,7 @@ def __init__( loss. max_seq_len: The maximum sequence length to expect. padding_side: The side of the sequence which is padded. + """ self.pad_token_id = pad_token_id self.ignore_index = ignore_index @@ -99,6 +103,7 @@ def __call__( Returns: ------- A dictionary containing a batch that we can input to our model. + """ batch = {} keys = ["input_ids", "labels"] diff --git a/vectorlm/utils/misc_utils.py b/vectorlm/utils/misc_utils.py index 30c1d67..6e8cc1a 100644 --- a/vectorlm/utils/misc_utils.py +++ b/vectorlm/utils/misc_utils.py @@ -5,6 +5,7 @@ import torch.distributed as dist import wandb + from vectorlm.utils.data_utils import Config diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index de36839..da6957c 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -2,7 +2,7 @@ import functools import re -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch import torch.distributed as dist @@ -77,6 +77,7 @@ def load_peft_model_and_tokenizer( Returns: ------- The PEFT model and tokenizer. + """ model, tokenizer = load_model_and_tokenizer( path, @@ -94,11 +95,15 @@ def load_peft_model_and_tokenizer( return peft_model, tokenizer + def load_model_and_tokenizer( path: str, use_mp: bool, use_fa: bool, max_seq_len: int, + local_rank: int, + low_cpu_mem_usage: bool, + use_safetensors: bool = True, ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: """Load the model and tokenizer. @@ -108,13 +113,19 @@ def load_model_and_tokenizer( use_mp: Whether to use mixed-precision. use_fa: Whether to use Flash Attention 2. max_seq_len: The maximum sequence length. + local_rank: The local rank of the current worker. + low_cpu_mem_usage: Whether to only load model weights on main rank, and + then scatter them to the other workers. + use_safetensors: Whether to use HF safe tensors. Note that this format + loads significantly faster. Returns: ------- The model and tokenizer. + """ # load model - model_args = {"use_cache": False} + model_args = {"use_cache": False, "use_safetensors": use_safetensors} if use_mp: model_args["torch_dtype"] = torch.bfloat16 @@ -123,10 +134,18 @@ def load_model_and_tokenizer( msg = "Use FA with bf16 (mixed precision)" raise ValueError(msg) model_args["attn_implementation"] = "flash_attention_2" - model = AutoModelForCausalLM.from_pretrained( - path, - **model_args, - ) + + if not low_cpu_mem_usage or local_rank == 0: + model = AutoModelForCausalLM.from_pretrained( + path, + **model_args, + ) + else: + with torch.device("meta"): + model = AutoModelForCausalLM.from_pretrained( + path, + **model_args, + ) # load tokenizer tokenizer = AutoTokenizer.from_pretrained(path) @@ -143,7 +162,13 @@ def load_model_and_tokenizer( return model, tokenizer -def fsdp_config(use_mp: bool, model: nn.Module, strategy: str) -> dict[str, Any]: +def fsdp_config( + use_mp: bool, + model: nn.Module, + strategy: str, + local_rank: int, + low_cpu_mem_usage: bool, +) -> dict[str, Any]: """Get FSDP config. Args: @@ -151,11 +176,23 @@ def fsdp_config(use_mp: bool, model: nn.Module, strategy: str) -> dict[str, Any] use_mp: Whether to use mixed-precision. model_to_wrap: The HuggingFace model to wrap using FSDP. strategy: The sharding strategy to use. + local_rank: The local rank of the current worker. + low_cpu_mem_usage: Whether to only load model weights on main rank, and + then scatter them to the other workers. Returns: ------- A dictionary containing the configurations. + """ + + def _module_init_fn(module: nn.Module) -> Callable: + """Return the function used for initializing modules on FSDP workers.""" + return module.to_empty( + device=torch.cuda.current_device(), + recurse=False, + ) + strategy_exists = hasattr(ShardingStrategy, strategy) if not strategy_exists: msg = f"The sharding strategy {strategy} does not exist." @@ -175,6 +212,9 @@ def fsdp_config(use_mp: bool, model: nn.Module, strategy: str) -> dict[str, Any] ret_dict["auto_wrap_policy"] = fsdp_auto_wrap_policy(model) ret_dict["sharding_strategy"] = sharding_strategy ret_dict["device_id"] = torch.cuda.current_device() + if low_cpu_mem_usage: + ret_dict["param_init_fn"] = _module_init_fn if local_rank != 0 else None + ret_dict["sync_module_states"] = True return ret_dict @@ -184,6 +224,8 @@ def shard_model( use_mp: bool, use_activation_checkpointing: bool, strategy: str, + local_rank: int, + low_cpu_mem_usage: bool, ) -> nn.Module: """Shard the model to workers using FSDP. @@ -194,12 +236,18 @@ def shard_model( use_mp: Whether to use mixed-precision. use_activation_checkpointing: Whether to use activation checkpointing. strategy: The sharding strategy to use. + local_rank: The local rank of the current worker. + low_cpu_mem_usage: Whether to only load model weights on main rank, and + then scatter them to the other workers. Returns: ------- The sharded module with the requested configurations. + """ - fsdp_cfg = fsdp_config(use_mp, model, strategy) + fsdp_cfg = fsdp_config( + use_mp, model, strategy, local_rank, low_cpu_mem_usage, + ) if dist.get_rank() == 0: print(f"FSDP config: {fsdp_cfg}") model = FSDP(model, **fsdp_cfg) @@ -223,6 +271,7 @@ def hook_activation_checkpointing( ---- model: The model we are using. layer: The layer to which we hook activation checkpointing to. + """ non_reentrant_wrapper = functools.partial( checkpoint_wrapper, diff --git a/vectorlm/utils/optimizer_utils.py b/vectorlm/utils/optimizer_utils.py index ff479a9..2d99d04 100644 --- a/vectorlm/utils/optimizer_utils.py +++ b/vectorlm/utils/optimizer_utils.py @@ -31,6 +31,7 @@ class PlateaeuWithWarmup(ReduceLROnPlateau): The maximum LR is determined by the number of warmup steps and the current step. base_lrs: A list of base LRs for the optimizer's param groups. + """ def __init__( @@ -63,6 +64,7 @@ def __init__( otherwise the update is ignored. verbose: Whether to print messages to stdout every LR update. num_warmup_steps: The number of steps we warmup the LR for. + """ super().__init__( optimizer=optimizer, @@ -85,6 +87,7 @@ def step(self, metrics: float, epoch: int | None = None) -> None: --------- metrics: The metric we are using to measure change in LR. epoch: The current step. + """ if epoch is None: epoch = self.last_epoch + 1 @@ -159,9 +162,11 @@ def get_custom_scheduler( name: The name of the scheduler args: The scheduler specific args. kwargs: The scheduler specific kwargs. - + Returns: + ------- The scheduler. + """ if name == "plataeu-with-warmup": scheduler = PlateaeuWithWarmup(*args, **kwargs) diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index 36998af..816e2ec 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -28,6 +28,7 @@ def checkpoint_exists(output_dir: str) -> bool: Returns: ------- Returns whether a checkpoint exists. + """ if os.path.isdir(os.path.join(output_dir, "checkpoints")): return True @@ -44,6 +45,7 @@ def save_metadata( ---- out_dir: The directory to save to. meta_dict: The dictionary containing the meta data. + """ os.makedirs(out_dir, exist_ok=True) torch.save(meta_dict, os.path.join(out_dir, "meta_data.pkl")) @@ -62,6 +64,7 @@ def load_metadata( ------- A tuple containing the checkpointed step, epoch, and the processed training dataset ids. + """ save_path = os.path.join(in_dir, "meta_data.pkl") meta_dict = torch.load(save_path) @@ -81,6 +84,7 @@ def get_latest_checkpoint_dir(folder_path: str) -> str: Returns: ------- The subpath (i.e. two levels) of the latest checkpoint's directory. + """ epoch_pattern = re.compile(r"^epoch_(\d+)$") folder_pattern = re.compile(r"^checkpoint_(\d+)$") @@ -112,6 +116,7 @@ def save_model(model: nn.Module, output_dir: str, rank: int) -> None: model: The sharded model. output_dir: The checkpointing directory. rank: The worker's rank. + """ os.makedirs(output_dir, exist_ok=True) weights_name = f"model_rank{rank}.bin" @@ -131,6 +136,7 @@ def load_model(model: nn.Module, input_dir: str, rank: int) -> None: model: The sharded model. input_dir: The checkpointing directory. rank: The worker's rank. + """ weights_name = f"model_rank{rank}.bin" input_model_file = os.path.join(input_dir, weights_name) @@ -154,6 +160,7 @@ def save_consolidated_model( model: The sharded model. save_dir: The checkpointing directory. rank: The worker's rank. + """ os.makedirs(save_dir, exist_ok=True) cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) @@ -178,6 +185,7 @@ def save_optimizer( model: The sharded model. output_dir: The checkpointing directory. rank: The worker's rank. + """ opt_name = f"optimizer_rank{rank}.bin" output_optimizer_file = os.path.join(output_dir, opt_name) @@ -207,6 +215,7 @@ def load_optimizer( model: The sharded model. input_dir: The checkpointing directory. rank: The worker's rank. + """ opt_name = f"optimizer_rank{rank}.bin" input_optimizer_file = os.path.join(input_dir, opt_name) @@ -237,6 +246,7 @@ def save_scheduler( scheduler: The LR scheduler. output_dir: The checkpointing directory. rank: The worker's rank. + """ sched_name = f"scheduler_rank{rank}.bin" output_scheduler_file = os.path.join(output_dir, sched_name) @@ -258,6 +268,7 @@ def load_scheduler( scheduler: The LR scheduler. input_dir: The checkpointing directory. rank: The worker's rank. + """ sched_name = f"scheduler_rank{rank}.bin" input_scheduler_file = os.path.join(input_dir, sched_name) From 48f61d92cf66ce88a31303a857d8fcd0a1901474 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:01:00 -0400 Subject: [PATCH 20/89] Revised launch_benchmark.py to use new profiling path. --- profiling/launch_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index f54e758..e76b71c 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -38,7 +38,7 @@ slurm_flags_extra = {"time": "00:30:00", "qos": qos_selected} -slurm_pos_args_options = [["examples/launch_lora_benchmark.sh"], model_list] +slurm_pos_args_options = [["profiling/launch_lora_benchmark.sh"], model_list] timestamp = int(time.time()) args_list: List[List[str]] = [] From 9876ebe150a42ca7294a9b4c588462408b3b3559 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:09:14 -0400 Subject: [PATCH 21/89] Enabled automatic creation of data/trace folder. --- profiling/benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index fb2ffe8..00ea7be 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -55,6 +55,7 @@ def parse_args() -> Namespace: # unix timestamp launch_time = time.time() os.makedirs("data/benchmark", exist_ok=True) +os.makedirs("data/trace", exist_ok=True) output_path = "data/benchmark/{}.jsonl".format(launch_time) profiler_output_path = "data/trace/{}.json".format(launch_time) From 53308717155ae8835b0812a2bbbb8f33c78a1b99 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:13:17 -0400 Subject: [PATCH 22/89] Added instructions for profiling tools. --- README.md | 2 ++ profiling/README.md | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 profiling/README.md diff --git a/README.md b/README.md index d98e86a..99e3d94 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,8 @@ At the end of training, a consolidated model will be saved under your output dir We provide an additional example of parameter-efficient fine-tuning (PEFT) using LoRA and FSDP. Use the [`examples/launch_lora.sh`](examples/launch_lora.sh) to launch your job on the cluster. +At the end of the training, the LoRA adapter folder will be saved in your output directory. This folder can be loaded directly through the `peft` library through the + # Contributors Adil Asif, Ziwen Han, John Willes, Jacob-Junqi Tian. diff --git a/profiling/README.md b/profiling/README.md new file mode 100644 index 0000000..143379b --- /dev/null +++ b/profiling/README.md @@ -0,0 +1,21 @@ +# Profiling Utils + +To modify the specific SLURM resources types to benchmark, adjust the launcher script `launch_benchmark.py` as needed. Modify `profiling/configs/lora-benchmark.yaml` to adjust parameters such as batch size and token width. + +On the Vector cluster, run the following to launch the benchmarks: + +```bash +$ mkdir data/ +$ python3 launch_benchmark.py + +# The launcher script will print a list of +# SLURM commands it plans to run. Press ENTER +# to accept and automatically invoke the comands. +``` + +After the SLURM jobs complete, profiler output can be found under `/data/benchmark`. + + +## TODO + +Add script for automatically parsing profiler output files. \ No newline at end of file From 9982791d2bb6bac695d904bc5f231c3a5c789b33 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:26:21 -0400 Subject: [PATCH 23/89] Cleaned up duplicate imports from merge. --- parse_benchmark.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 parse_benchmark.py diff --git a/parse_benchmark.py b/parse_benchmark.py new file mode 100644 index 0000000..1f7ea71 --- /dev/null +++ b/parse_benchmark.py @@ -0,0 +1,11 @@ +""" +Parse benchmarking results +to generate metrics overview table. +""" + +import json + + +# Load all files + +# From 9a76e80c7cad25926d204dfbd4163e9ca1b87106 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:26:56 -0400 Subject: [PATCH 24/89] Cleaned up duplicate imports from merge. --- vectorlm/trainer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index b07c70e..9c024fe 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -8,7 +8,6 @@ import torch import torch.distributed as dist import wandb -import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from transformers import PreTrainedTokenizer @@ -337,10 +336,7 @@ def eval_step(self, epoch: int) -> float: batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) num_tokens = len(batch["input_ids"].flatten()) batch["labels"] = batch["labels"].type(torch.LongTensor) - batch = { - k: v.to(torch.cuda.current_device()) - for k, v in batch.items() - } + batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} with self.timer_handle("eval_step", {"num_tokens": num_tokens}): out = self.model(**batch) From ffa7067e4c5e790a86954f5878a063c1bf1c725c Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:27:53 -0400 Subject: [PATCH 25/89] Cleaned up parse_benchmark.py --- parse_benchmark.py | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 parse_benchmark.py diff --git a/parse_benchmark.py b/parse_benchmark.py deleted file mode 100644 index 1f7ea71..0000000 --- a/parse_benchmark.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -Parse benchmarking results -to generate metrics overview table. -""" - -import json - - -# Load all files - -# From bd893e1b643b1a67d1ce6f42a560f471024a72ad Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:32:21 -0400 Subject: [PATCH 26/89] Integrated LoRA logic into llama_example.py. --- examples/example_lora.py | 161 --------------------------------------- 1 file changed, 161 deletions(-) delete mode 100644 examples/example_lora.py diff --git a/examples/example_lora.py b/examples/example_lora.py deleted file mode 100644 index 2f3a61e..0000000 --- a/examples/example_lora.py +++ /dev/null @@ -1,161 +0,0 @@ -# Renamed from examples/llama_example.py -import argparse -import math -import os -import sys -from argparse import Namespace - -import torch -import torch.distributed as dist -from torch.optim import AdamW -from tqdm import tqdm -from transformers import set_seed -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -from vectorlm.dataset import Dataset -from vectorlm.trainer import Trainer -from vectorlm.utils.data_utils import Config -from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup -from vectorlm.utils.model_utils import ( - hook_activation_checkpointing, - initialize_lora_model_and_tokenizer, - shard_model, -) -from vectorlm.utils.optimizer_utils import get_custom_scheduler -from vectorlm.utils.save_utils import save_consolidated_model - - -def parse_args() -> Namespace: - """Parse command-line arguments. - - Returns - ------- - The parsed arguments. - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "--yaml_path", - default="configs/config.yaml", - required=False, - ) - return parser.parse_args() - - -def main(config: Config) -> None: - """Define the main calling function.""" - training_args = config.train_parameters - - # set a seed - set_seed(training_args.seed) - - # set CUDA related dependencies - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - print(f"Rank: {rank}, World size: {world_size}") - if dist.is_initialized(): - torch.cuda.set_device(local_rank) - torch.cuda.empty_cache() - - # setup wandb - if rank == 0: - wandb_setup(config, **config.wandb_config) - dist.barrier() - - # load model and tokenizer - state_dict_path = getattr(config, "state_dict", None) - - model, tokenizer = initialize_lora_model_and_tokenizer( - config.model, - training_args.use_mp, - training_args.use_flash_attention, - training_args.max_seq_len, - config.lora_peft_config, - ) - - model = shard_model( - model, - LlamaDecoderLayer, - training_args.use_mp, - training_args.use_activation_checkpointing, - training_args.sharding_strategy, - ) - - if training_args.use_activation_checkpointing: - hook_activation_checkpointing(model, LlamaDecoderLayer) - - # load dataset - dataset = Dataset( - config=config.dataset, - tokenizer=tokenizer, - ) - - # instantiate trainer - trainer = Trainer( - config=training_args, - enable_wandb_logging=config.enable_wandb_logging, - original_dataset_length=dataset.original_length, - ) - - # load optimizer - optimizer = AdamW( - model.parameters(), - **training_args.optimizer, - ) - - # load lr scheduler - lr_scheduler = get_custom_scheduler( - training_args.lr_scheduler_type, - optimizer, - math.ceil( - trainer.num_update_steps_per_epoch * training_args.warmup_ratio, - ), - trainer.max_steps, - ) - - trainer.prepare_trainer( - model, - tokenizer, - dataset, - optimizer, - lr_scheduler, - ) - - # TODO: support restoring LoRA fine-tuning - trainer.dataset.setup_dataloaders() - checkpointed_epoch = 0 - - for epoch in range(checkpointed_epoch, training_args.epochs): - trainer.model.train() - train_dl_iterator = iter(dataset.train_dataloader) - for _ in tqdm( - range(len(dataset.train_dataloader)), - disable=rank != 0, - file=sys.__stdout__, - ): - batch = next(train_dl_iterator) - trainer.step(batch, epoch) - - if epoch == training_args.epochs - 1: - hf_save_dir = os.path.join(training_args.output_dir, "final-model") - else: - hf_save_dir = os.path.join( - training_args.output_dir, - "checkpoints", - f"epoch_{epoch}", - "end-epoch-model", - ) - save_consolidated_model(trainer.model, hf_save_dir, rank) - if rank == 0: - tokenizer.save_pretrained(hf_save_dir) - - dataset.reset_dataloaders() - - -if __name__ == "__main__": - args = parse_args() - config = Config(yaml_path=args.yaml_path) - setup(config.train_parameters.output_dir) - main(config) - cleanup() From c2f346f26fed8e3f01ba4065d7a6ab30f6641583 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 17:41:05 -0400 Subject: [PATCH 27/89] Moved lora_configs into train_parameters in config yaml. Adjusted docs/config.md accordingly. --- configs/config-lora.yaml | 66 --------------------------- configs/config.yaml | 8 ++++ docs/config.md | 1 + examples/llama_example.py | 19 ++++++-- profiling/benchmark.py | 2 +- profiling/configs/lora-benchmark.yaml | 13 +++--- 6 files changed, 32 insertions(+), 77 deletions(-) delete mode 100644 configs/config-lora.yaml diff --git a/configs/config-lora.yaml b/configs/config-lora.yaml deleted file mode 100644 index e9eca95..0000000 --- a/configs/config-lora.yaml +++ /dev/null @@ -1,66 +0,0 @@ -model: /model-weights/Llama-2-7b-chat-hf -enable_wandb_logging: True - -lora_peft_config: - task_type: CAUSAL_LM - inference_mode: False - r: 8 - lora_alpha: 32 - lora_dropout: 0.1 - -wandb_config: - project: MedGPT - name: Llama-2-7B-chat - -train_parameters: - output_dir: your/output/dir - max_seq_len: 1024 - epochs: 1 - seed: 11 - - # Sharding strategy - sharding_strategy: FULL_SHARD - - # Memory - use_mp: True - use_activation_checkpointing: True - use_flash_attention: True - - # Gradient norm clipping - max_grad_norm: 1 - gradient_accumulation_steps: 4 - - # Optimizer - optimizer: - lr: 2.0e-5 - weight_decay: 0.1 - betas: [0.9, 0.95] - eps: 1.0e-5 - - # Scheduler - lr_scheduler_type: cosine - warmup_ratio: 0.05 - - # Checkpointing - checkpointing_enabled: True - logging_steps: 500 - save_frequency: 0.25 - -dataset: - ignore_index: -100 - eval_bs: 8 - train_bs: 8 - train_ds: your/train/ds - eval_ds: your/eval/ds - -dataset_preprocess: - ignore_index: -100 - dataset_format: hf - data_field: text - packing_type: partial - add_bos_eos_tokens: True - from_disk: True - load_path: your/unprocessed/dataset - split: train - save_path: dir/to/save/processed/dataset - truncate: False diff --git a/configs/config.yaml b/configs/config.yaml index 4ce2dde..b7bc77f 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -20,6 +20,14 @@ train_parameters: use_flash_attention: True low_cpu_mem_usage: True + # # Uncomment below to enable LoRA + # lora_peft_config: + # task_type: CAUSAL_LM + # inference_mode: False + # r: 8 + # lora_alpha: 32 + # lora_dropout: 0.1 + # Gradient norm clipping max_grad_norm: 1 gradient_accumulation_steps: 4 diff --git a/docs/config.md b/docs/config.md index 0541c10..ef0de37 100644 --- a/docs/config.md +++ b/docs/config.md @@ -29,6 +29,7 @@ The key-value pairs stored under `wandb_config` are directly passed into the [`w * `use_activation_checkpointing`: Whether to use activation checkpointing. This greatly reduces memory footprint as only a few intermediate activations as saved during the forward pass, and are then recomputed for the backward pass on the fly. However, the tradeoff between compute vs. memory usually makes this worth it. * `use_flash_attention`: Whether to use Flash Attention. If it is supported for your model in HuggingFace, you can enable this option. * `low_cpu_mem_usage`: Whether to efficiently load the model. If enabled, the model weights are only loaded once on rank 0 and are broadcasted to the rest of the world from the main rank. It will prevent the CPU memory from exploding when loading large models (e.g. LLaMa-70B). +* `lora_peft_config`: Optionally, fine-tune the model using LoRA, using the HuggingFace PEFT implementation. Uncomment this section to enable LoRA. If LoRA is enabled, training output will consist only of the lora adapters. You can merge the lora adapter into the base model using utilities from the PEFT library. ### Gradient diff --git a/examples/llama_example.py b/examples/llama_example.py index d794740..3d284ec 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -15,7 +15,12 @@ from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup -from vectorlm.utils.model_utils import load_model_and_tokenizer, shard_model +from vectorlm.utils.model_utils import ( + get_lora_model_from_base_model, + get_submodule_by_pattern, + load_model_and_tokenizer, + shard_model, +) from vectorlm.utils.optimizer_utils import get_custom_scheduler from vectorlm.utils.save_utils import save_consolidated_model @@ -30,7 +35,9 @@ def parse_args() -> Namespace: """ parser = argparse.ArgumentParser() parser.add_argument( - "--yaml_path", default="configs/config.yaml", required=False, + "--yaml_path", + default="configs/config.yaml", + required=False, ) return parser.parse_args() @@ -67,9 +74,14 @@ def main(config: Config) -> None: training_args.low_cpu_mem_usage, ) + lora_peft_config = getattr(config.train_parameters, "lora_peft_config", None) + if lora_peft_config is not None: + model = get_lora_model_from_base_model(model, lora_peft_config) + + decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") model = shard_model( model, - LlamaDecoderLayer, + decoder_layer_module, training_args.use_mp, training_args.use_activation_checkpointing, training_args.sharding_strategy, @@ -141,6 +153,7 @@ def main(config: Config) -> None: save_consolidated_model(trainer.model, hf_save_dir, rank) dataset.reset_dataloaders() + if __name__ == "__main__": args = parse_args() config = Config(yaml_path=args.yaml_path) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 00ea7be..4d3c6cd 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -223,7 +223,7 @@ def main(config: Config, model_name: str) -> None: # load model and tokenizer state_dict_path = getattr(config, "state_dict", None) - lora_peft_config = getattr(config, "state_dict", None) + lora_peft_config = getattr(config.train_parameters, "lora_peft_config", None) with track_time("model_load"): model, tokenizer = load_model_and_tokenizer( diff --git a/profiling/configs/lora-benchmark.yaml b/profiling/configs/lora-benchmark.yaml index dbd6dfe..7306d7b 100644 --- a/profiling/configs/lora-benchmark.yaml +++ b/profiling/configs/lora-benchmark.yaml @@ -1,12 +1,5 @@ enable_wandb_logging: True -lora_peft_config: - task_type: CAUSAL_LM - inference_mode: False - r: 8 - lora_alpha: 32 - lora_dropout: 0.1 - wandb_config: project: vector-lm-verify name: benchmark-lora @@ -27,6 +20,12 @@ train_parameters: # for CUDA capability > 8.0 low_cpu_mem_usage: True + lora_peft_config: + task_type: CAUSAL_LM + inference_mode: False + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 # Gradient norm clipping max_grad_norm: 1 From 56cb750e9bad3061b0d7dd7ed8bd59020a958a15 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 21:49:32 -0400 Subject: [PATCH 28/89] Revised handling of nproc-per-node in benchmark script. --- examples/launch.sh | 2 +- profiling/launch_lora_benchmark.sh | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 profiling/launch_lora_benchmark.sh diff --git a/examples/launch.sh b/examples/launch.sh index 25f6ea0..92f0ad1 100644 --- a/examples/launch.sh +++ b/examples/launch.sh @@ -23,4 +23,4 @@ export LOGLEVEL=INFO export PYTHONFAULTHANDLER=1 # export CUDA_LAUNCH_BLOCKING=0 -torchrun --nnodes=1 --nproc-per-node=4 llama_example.py --yaml_path ../configs/config.yaml +torchrun --nnodes=1 --nproc-per-node=${SLURM_STEP_GPUS} llama_example.py --yaml_path ../configs/config.yaml diff --git a/profiling/launch_lora_benchmark.sh b/profiling/launch_lora_benchmark.sh new file mode 100644 index 0000000..fb4487e --- /dev/null +++ b/profiling/launch_lora_benchmark.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. +export NCCL_DEBUG=WARN +export NCCL_DEBUG_SUBSYS=WARN + +# export TORCH_DISTRIBUTED_DEBUG=DETAIL # Uncomment these flags for debugging communication +# export TORCH_CPP_LOG_LEVEL=INFO +export LOGLEVEL=INFO +export PYTHONFAULTHANDLER=1 +# export CUDA_LAUNCH_BLOCKING=0 + +source ~/vectorlm/env/bin/activate +export PYTHONPATH=$PYTHONPATH:`pwd` + +nvidia-smi + +torchrun \ +--nnodes=1 \ +--nproc-per-node=${SLURM_STEP_GPUS} profiling/benchmark.py \ +--yaml_path profiling/configs/lora-benchmark.yaml \ +--model_name $1 \ No newline at end of file From 97ddd8c0a780a5ed8762c10b43d7580117ec7cf2 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 21:49:51 -0400 Subject: [PATCH 29/89] Included parameter_count info in benchmark output. --- profiling/benchmark.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 4d3c6cd..81bcfa1 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -253,6 +253,14 @@ def main(config: Config, model_name: str) -> None: local_rank, training_args.low_cpu_mem_usage, ) + per_device_parameter_count = sum(p.numel() for p in model.parameters()) + track_time( + "parameter_count", + { + "per_device": per_device_parameter_count, + "total": per_device_parameter_count * world_size, + }, + ) with track_time("set_activation_checkpointing"): if training_args.use_activation_checkpointing: From 7c7a00053b6c6738533bd5767534af5529a1de01 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 11 Mar 2024 21:50:19 -0400 Subject: [PATCH 30/89] Implemented basic util for parsing benchmarking output. --- profiling/parse_benchmark.py | 69 ++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 profiling/parse_benchmark.py diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py new file mode 100644 index 0000000..bebff54 --- /dev/null +++ b/profiling/parse_benchmark.py @@ -0,0 +1,69 @@ +""" +Parse benchmarking results +to generate metrics overview table. +""" + +from collections import defaultdict +import os +import json +import glob + +import pandas + +benchmark_artifact_folder = "data/benchmark/" + +# Load all benchmark result jsonl files +benchmark_jsonl_list = glob.glob("*.jsonl", root_dir=benchmark_artifact_folder) +raw_benchmarks = [] +for jsonl_filename in benchmark_jsonl_list: + jsonl_path = os.path.join(benchmark_artifact_folder, jsonl_filename) + with open(jsonl_path, "r") as jsonl_file: + benchmark_content = [ + json.loads(line) for line in jsonl_file.read().splitlines() + ] + benchmark_content.append({"name": "_source", "value": jsonl_path}) + + raw_benchmarks.append(benchmark_content) + +# (model_name, device) +aggregated_output = defaultdict(dict) +for raw_benchmark in raw_benchmarks: + example_output = {} + + # Need to implement alternative reducing method + # string: most recent + # number: summation + for line in raw_benchmark: + name = line["name"] + value = line["value"] + example_output[name] = value + + model_name = example_output.get("model_name") + if model_name is None: + continue + + model_name = model_name.split("/")[-1] + source_filename = example_output["_source"] + + device_info = example_output["device_info"] + device_name = device_info["device_name"] + world_size = device_info["world_size"] + device_description = "{} x{}".format(device_name, world_size) + + if world_size > 1: + print(source_filename) + + train_step = example_output.get("train_step") + if train_step is not None: + train_throughput = train_step["num_tokens"] / train_step["time_elapsed"] + else: + train_throughput = None + + aggregated_output[model_name][device_description] = train_throughput + +throughput_table = pandas.DataFrame(aggregated_output).T + +print(throughput_table) + +with open("data/benchmark/table.md", "w") as table_output_file: + table_output_file.write(throughput_table.to_markdown()) From f33e89ae35be7337c2f28b96dc655ecf0efe8580 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 12 Mar 2024 08:08:50 -0400 Subject: [PATCH 31/89] model_utils: Enabled low_cpu_mem_usage in auto model from_pretrained by default. --- vectorlm/utils/model_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index d43e1a2..8cfc333 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -118,7 +118,7 @@ def load_model_and_tokenizer( use_safetensors: Whether to use HF safe tensors. Note that this format loads significantly faster. local_rank: The local rank of the current worker. - + Returns: ------- The model and tokenizer. @@ -139,12 +139,14 @@ def load_model_and_tokenizer( if not low_cpu_mem_usage or local_rank == 0: model = AutoModelForCausalLM.from_pretrained( path, + low_cpu_mem_usage=True, **model_args, ) else: with torch.device("meta"): model = AutoModelForCausalLM.from_pretrained( path, + low_cpu_mem_usage=True, **model_args, ) From 35bdbcd9f9f741c03459f18a9705a8c85a10c76b Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 13 Mar 2024 10:59:46 -0400 Subject: [PATCH 32/89] launch_lora_benchmark.sh: implemented automatic identification of num_gpus. lora-benchmark: switched parse_benchmark: implemented option to specify benchmark artifact folder to load. --- profiling/benchmark.py | 1 + profiling/configs/lora-benchmark.yaml | 2 +- profiling/launch_benchmark.py | 6 +++--- profiling/launch_lora_benchmark.sh | 9 +++++++-- profiling/parse_benchmark.py | 20 ++++++++++++++------ 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 81bcfa1..28110d2 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -191,6 +191,7 @@ def load_datasets(self) -> None: def main(config: Config, model_name: str) -> None: """Define the main calling function.""" + print("Writing metrics to {}".format(output_path)) write_metrics("model_name", model_name) write_metrics("config", {**config.__dict__}) write_metrics("device_info", get_device_info()) diff --git a/profiling/configs/lora-benchmark.yaml b/profiling/configs/lora-benchmark.yaml index 7306d7b..4105404 100644 --- a/profiling/configs/lora-benchmark.yaml +++ b/profiling/configs/lora-benchmark.yaml @@ -5,7 +5,7 @@ wandb_config: name: benchmark-lora train_parameters: - output_dir: /tmp/lora-benchmark + output_dir: /dev/shm/lora-benchmark max_seq_len: 128 epochs: 1 seed: 11 diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index e76b71c..b752fbd 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -23,7 +23,7 @@ "opt-350m", "Llama-2-7b-hf", "Llama-2-13b-hf", - "Mistral-7B-v0.1", + "Mixtral-8x7B-Instruct-v0.1", ] ] @@ -31,12 +31,12 @@ "nodes": [1], "mem": [0], "ntasks-per-node": [1], - "cpus-per-gpu": [6], + "cpus-per-gpu": [3], "gres": ["gpu:{}".format(n + 1) for n in range(8)], "partition": ["t4v2", "a40", "a100"], } -slurm_flags_extra = {"time": "00:30:00", "qos": qos_selected} +slurm_flags_extra = {"time": "00:15:00", "qos": qos_selected} slurm_pos_args_options = [["profiling/launch_lora_benchmark.sh"], model_list] timestamp = int(time.time()) diff --git a/profiling/launch_lora_benchmark.sh b/profiling/launch_lora_benchmark.sh index fb4487e..d10ce7c 100644 --- a/profiling/launch_lora_benchmark.sh +++ b/profiling/launch_lora_benchmark.sh @@ -14,9 +14,14 @@ source ~/vectorlm/env/bin/activate export PYTHONPATH=$PYTHONPATH:`pwd` nvidia-smi +export num_gpus=`nvidia-smi -L | wc -l` +echo num_gpus: ${num_gpus} torchrun \ --nnodes=1 \ ---nproc-per-node=${SLURM_STEP_GPUS} profiling/benchmark.py \ +--nproc-per-node=${num_gpus} profiling/benchmark.py \ --yaml_path profiling/configs/lora-benchmark.yaml \ ---model_name $1 \ No newline at end of file +--model_name $1 + +# # clean up benchmarking artifacts as ops have requested +rm -rf /dev/shm/lora-benchmark \ No newline at end of file diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index bebff54..b4138a2 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -3,6 +3,7 @@ to generate metrics overview table. """ +import argparse from collections import defaultdict import os import json @@ -10,7 +11,11 @@ import pandas -benchmark_artifact_folder = "data/benchmark/" + +parser = argparse.ArgumentParser() +parser.add_argument("--folder", default="data/benchmark/") +args = parser.parse_args() +benchmark_artifact_folder = args.folder # Load all benchmark result jsonl files benchmark_jsonl_list = glob.glob("*.jsonl", root_dir=benchmark_artifact_folder) @@ -50,12 +55,11 @@ world_size = device_info["world_size"] device_description = "{} x{}".format(device_name, world_size) - if world_size > 1: - print(source_filename) - train_step = example_output.get("train_step") if train_step is not None: - train_throughput = train_step["num_tokens"] / train_step["time_elapsed"] + train_throughput = ( + world_size * train_step["num_tokens"] / train_step["time_elapsed"] + ) else: train_throughput = None @@ -65,5 +69,9 @@ print(throughput_table) -with open("data/benchmark/table.md", "w") as table_output_file: +with open( + os.path.join(benchmark_artifact_folder, "table.md"), "w" +) as table_output_file: table_output_file.write(throughput_table.to_markdown()) + +print(example_output.get("profiler_table")) From e6b2e594929835af5f9df7ef4533ba9f46a8a0f0 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 13 Mar 2024 11:06:21 -0400 Subject: [PATCH 33/89] requirements.txt: included accelerate to support low_cpu_mem loading. --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index bdf0065..04f0da2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +accelerate datasets transformers sentencepiece From db148face1776dc0df44270349b1126ac62ee47e Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 13 Mar 2024 11:21:30 -0400 Subject: [PATCH 34/89] benchmark.py: adjusted BenchmarkingDataset to avoid StopIteration exception. --- profiling/benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 28110d2..6c91c42 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -183,7 +183,7 @@ def load_datasets(self) -> None: "labels": torch.zeros(1024), "attention_mask": torch.ones(1024), } - for row_id in range(1024) + for row_id in range(8192) ] self.eval_ds = self.train_ds self.original_length = math.ceil(len(self.train_ds) / self.train_bs) @@ -236,6 +236,7 @@ def main(config: Config, model_name: str) -> None: training_args.low_cpu_mem_usage, ) if lora_peft_config is not None: + print("Enabling LoRA Wrapper.") model = get_lora_model_from_base_model(model, lora_peft_config) decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") From 35f6c5de015769291cb6f003f77110165b35ec25 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 11:53:21 -0400 Subject: [PATCH 35/89] benchmark.py: added env var flag to toggle export_trace --- profiling/benchmark.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 6c91c42..435b00b 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -170,7 +170,9 @@ def handle_profiler_trace(profiler_output: torch.autograd.profiler.profile): write_metrics("profiler_table", key_average_event_list.table()) parsed_output = parse_profiler_output(profiler_output) write_metrics("profiler_output", parsed_output) - profiler_output.export_chrome_trace(profiler_output_path) + + if bool(os.environ.get("PROFILER_EXPORT_TRACE")): + profiler_output.export_chrome_trace(profiler_output_path) class BenchmarkingDataset(Dataset): @@ -185,7 +187,7 @@ def load_datasets(self) -> None: } for row_id in range(8192) ] - self.eval_ds = self.train_ds + self.eval_ds = self.train_ds[: len(self.train_ds) // 10] self.original_length = math.ceil(len(self.train_ds) / self.train_bs) From 4a1251b1add7893d7eabc4121a3ad1faf33ff34a Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 12:05:20 -0400 Subject: [PATCH 36/89] parse_benchmark: included profiler table in output file. launch_benchmark: automated folder creation. launch_lora_benchmark: included model info in slurm output. --- profiling/launch_benchmark.py | 4 +++- profiling/launch_lora_benchmark.sh | 1 + profiling/parse_benchmark.py | 31 +++++++++++++++++++++--------- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index b752fbd..b982d52 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -6,6 +6,7 @@ import itertools import subprocess import time +from os import makedirs from typing import List from tqdm.auto import tqdm @@ -32,7 +33,7 @@ "mem": [0], "ntasks-per-node": [1], "cpus-per-gpu": [3], - "gres": ["gpu:{}".format(n + 1) for n in range(8)], + "gres": ["gpu:{}".format(n) for n in [1, 2, 4, 8]], "partition": ["t4v2", "a40", "a100"], } @@ -73,5 +74,6 @@ input("\nPress ENTER to launch {} job(s)".format(len(args_list))) +makedirs("data/output", exist_ok=True) for args in tqdm(args_list, ncols=75): subprocess.run(args) diff --git a/profiling/launch_lora_benchmark.sh b/profiling/launch_lora_benchmark.sh index d10ce7c..9a57050 100644 --- a/profiling/launch_lora_benchmark.sh +++ b/profiling/launch_lora_benchmark.sh @@ -16,6 +16,7 @@ export PYTHONPATH=$PYTHONPATH:`pwd` nvidia-smi export num_gpus=`nvidia-smi -L | wc -l` echo num_gpus: ${num_gpus} +echo model: $1 torchrun \ --nnodes=1 \ diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index b4138a2..e967020 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -8,6 +8,7 @@ import os import json import glob +from typing import List import pandas @@ -32,8 +33,9 @@ # (model_name, device) aggregated_output = defaultdict(dict) +profiler_tables = defaultdict(dict) for raw_benchmark in raw_benchmarks: - example_output = {} + benchmark_output = {} # Need to implement alternative reducing method # string: most recent @@ -41,21 +43,21 @@ for line in raw_benchmark: name = line["name"] value = line["value"] - example_output[name] = value + benchmark_output[name] = value - model_name = example_output.get("model_name") + model_name = benchmark_output.get("model_name") if model_name is None: continue model_name = model_name.split("/")[-1] - source_filename = example_output["_source"] + source_filename = benchmark_output["_source"] - device_info = example_output["device_info"] + device_info = benchmark_output["device_info"] device_name = device_info["device_name"] world_size = device_info["world_size"] device_description = "{} x{}".format(device_name, world_size) - train_step = example_output.get("train_step") + train_step = benchmark_output.get("train_step") if train_step is not None: train_throughput = ( world_size * train_step["num_tokens"] / train_step["time_elapsed"] @@ -64,14 +66,25 @@ train_throughput = None aggregated_output[model_name][device_description] = train_throughput + profiler_table_str = benchmark_output.get("profiler_table") + if profiler_table_str is not None: + profiler_tables[model_name][device_description] = profiler_table_str throughput_table = pandas.DataFrame(aggregated_output).T - +throughput_table.sort_index(axis="columns", inplace=True) +throughput_table.sort_index(axis="index", inplace=True) print(throughput_table) +table_output_lines: List[str] = [] with open( os.path.join(benchmark_artifact_folder, "table.md"), "w" ) as table_output_file: - table_output_file.write(throughput_table.to_markdown()) + table_output_lines.append(throughput_table.to_markdown()) + + for model_name, profiler_table_dict in profiler_tables.items(): + table_output_lines.append("\n## {}".format(model_name)) + for device_description, profiler_table_str in profiler_table_dict.items(): + table_output_lines.append("### {}".format(device_description)) + table_output_lines.append("```\n{}\n```".format(profiler_table_str)) -print(example_output.get("profiler_table")) + table_output_file.write("\n".join(table_output_lines)) From 79fd79bc4308b23d579241ea94298d948b711598 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 12:06:28 -0400 Subject: [PATCH 37/89] get_lora_model_from_base_model: enabled peft for models loaded via low_cpu_mem. More investigation might be needed. --- vectorlm/utils/model_utils.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 8cfc333..7024e06 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -27,6 +27,28 @@ ) +def _is_bfloat_available() -> bool: + """ + Return whether bfloat is supported for the + current CUDA device. + + Returns: + -------- + bool. True if bfloat is supported. + """ + cuda_capability = torch.cuda.get_device_capability() + cuda_capability_str = "{}.{}".format(*cuda_capability) + if cuda_capability[0] >= 8.0: + print("Hardware capability {}; bfloat is supported".format(cuda_capability_str)) + return True + + else: + print( + "Hardware capability {}; bfloat isn't supported".format(cuda_capability_str) + ) + return False + + def get_lora_model_from_base_model( base_model: nn.Module, peft_config_dict: Dict ) -> PeftModel: @@ -42,7 +64,17 @@ def get_lora_model_from_base_model( task_type = getattr(TaskType, task_type_str) lora_config = LoraConfig(**{**peft_config_dict, "task_type": task_type}) + # See github.com/pytorch/pytorch/pull/102212 + base_model.load_state_dict(base_model.state_dict(), assign=True) + if _is_bfloat_available(): + base_model = base_model.bfloat16() + else: + base_model = base_model.half() + + assert isinstance(base_model, PreTrainedModel) lora_model = get_peft_model(base_model, lora_config) + lora_model.print_trainable_parameters() + return lora_model @@ -139,14 +171,12 @@ def load_model_and_tokenizer( if not low_cpu_mem_usage or local_rank == 0: model = AutoModelForCausalLM.from_pretrained( path, - low_cpu_mem_usage=True, **model_args, ) else: with torch.device("meta"): model = AutoModelForCausalLM.from_pretrained( path, - low_cpu_mem_usage=True, **model_args, ) From 5c253977a4fc70d6e7098ca1c70b493f75447160 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 12:14:35 -0400 Subject: [PATCH 38/89] model_utils: revised dtype handling for peft-wrapped models. --- vectorlm/utils/model_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 7024e06..2bbf90f 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -66,15 +66,15 @@ def get_lora_model_from_base_model( # See github.com/pytorch/pytorch/pull/102212 base_model.load_state_dict(base_model.state_dict(), assign=True) + lora_model = get_peft_model(base_model, lora_config) + if _is_bfloat_available(): - base_model = base_model.bfloat16() + lora_model = lora_model.bfloat16() else: - base_model = base_model.half() + lora_model = lora_model.half() - assert isinstance(base_model, PreTrainedModel) - lora_model = get_peft_model(base_model, lora_config) + assert isinstance(lora_model, PeftModel) lora_model.print_trainable_parameters() - return lora_model From c19de829d3d18abab47d31e7ddd8abde76be0d95 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 13:00:35 -0400 Subject: [PATCH 39/89] parse_benchmark: implemented sorting of profiler table output. launch_benchmark: revised default run time limit. --- profiling/launch_benchmark.py | 2 +- profiling/parse_benchmark.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index b982d52..e96e4bf 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -37,7 +37,7 @@ "partition": ["t4v2", "a40", "a100"], } -slurm_flags_extra = {"time": "00:15:00", "qos": qos_selected} +slurm_flags_extra = {"time": "00:30:00", "qos": qos_selected} slurm_pos_args_options = [["profiling/launch_lora_benchmark.sh"], model_list] timestamp = int(time.time()) diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index e967020..149e80d 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -81,9 +81,14 @@ ) as table_output_file: table_output_lines.append(throughput_table.to_markdown()) - for model_name, profiler_table_dict in profiler_tables.items(): + model_names = sorted(list(profiler_tables.keys())) + for model_name in model_names: table_output_lines.append("\n## {}".format(model_name)) - for device_description, profiler_table_str in profiler_table_dict.items(): + profiler_table_dict = profiler_tables[model_name] + device_descriptions = sorted(list(profiler_table_dict.keys())) + + for device_description in device_descriptions: + profiler_table_str = profiler_table_dict[device_description] table_output_lines.append("### {}".format(device_description)) table_output_lines.append("```\n{}\n```".format(profiler_table_str)) From 7e13cde7b08b274fdf1e4febb4ff7c4f6badc01b Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 13:48:49 -0400 Subject: [PATCH 40/89] Merged example_lora into examples/llama_example.pu --- example_lora.py | 162 ------------------------------------------------ 1 file changed, 162 deletions(-) delete mode 100644 example_lora.py diff --git a/example_lora.py b/example_lora.py deleted file mode 100644 index a6b298d..0000000 --- a/example_lora.py +++ /dev/null @@ -1,162 +0,0 @@ -# Renamed from examples/llama_example.py -import argparse -import math -import os -import sys -from argparse import Namespace - -import torch -import torch.distributed as dist -from torch.optim import AdamW -from tqdm import tqdm -from transformers import set_seed -from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from peft.utils.other import fsdp_auto_wrap_policy - -from vectorlm.dataset import Dataset -from vectorlm.trainer import Trainer -from vectorlm.utils.data_utils import Config -from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup -from vectorlm.utils.model_utils import ( - hook_activation_checkpointing, - initialize_lora_model_and_tokenizer, - shard_model, -) -from vectorlm.utils.optimizer_utils import get_custom_scheduler -from vectorlm.utils.save_utils import save_consolidated_model - - -def parse_args() -> Namespace: - """Parse command-line arguments. - - Returns - ------- - The parsed arguments. - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "--yaml_path", - default="configs/config.yaml", - required=False, - ) - return parser.parse_args() - - -def main(config: Config) -> None: - """Define the main calling function.""" - training_args = config.train_parameters - - # set a seed - set_seed(training_args.seed) - - # set CUDA related dependencies - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - print(f"Rank: {rank}, World size: {world_size}") - if dist.is_initialized(): - torch.cuda.set_device(local_rank) - torch.cuda.empty_cache() - - # setup wandb - if rank == 0: - wandb_setup(config, **config.wandb_config) - dist.barrier() - - # load model and tokenizer - state_dict_path = getattr(config, "state_dict", None) - - model, tokenizer = initialize_lora_model_and_tokenizer( - config.model, - training_args.use_mp, - training_args.use_flash_attention, - training_args.max_seq_len, - config.lora_peft_config, - ) - - model = shard_model( - model, - LlamaDecoderLayer, - training_args.use_mp, - training_args.use_activation_checkpointing, - training_args.sharding_strategy, - ) - - if training_args.use_activation_checkpointing: - hook_activation_checkpointing(model, LlamaDecoderLayer) - - # load dataset - dataset = Dataset( - config=config.dataset, - tokenizer=tokenizer, - ) - - # instantiate trainer - trainer = Trainer( - config=training_args, - enable_wandb_logging=config.enable_wandb_logging, - original_dataset_length=dataset.original_length, - ) - - # load optimizer - optimizer = AdamW( - model.parameters(), - **training_args.optimizer, - ) - - # load lr scheduler - lr_scheduler = get_custom_scheduler( - training_args.lr_scheduler_type, - optimizer, - math.ceil( - trainer.num_update_steps_per_epoch * training_args.warmup_ratio, - ), - trainer.max_steps, - ) - - trainer.prepare_trainer( - model, - tokenizer, - dataset, - optimizer, - lr_scheduler, - ) - - # TODO: support restoring LoRA fine-tuning - trainer.dataset.setup_dataloaders() - checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - - for epoch in range(checkpointed_epoch, training_args.epochs): - trainer.model.train() - train_dl_iterator = iter(dataset.train_dataloader) - for _ in tqdm( - range(len(dataset.train_dataloader)), - disable=rank != 0, - file=sys.__stdout__, - ): - batch = next(train_dl_iterator) - trainer.step(batch, epoch) - - if epoch == training_args.epochs - 1: - hf_save_dir = os.path.join(training_args.output_dir, "final-model") - else: - hf_save_dir = os.path.join( - training_args.output_dir, - "checkpoints", - f"epoch_{epoch}", - "end-epoch-model", - ) - save_consolidated_model(trainer.model, hf_save_dir, rank) - if rank == 0: - tokenizer.save_pretrained(hf_save_dir) - - dataset.reset_dataloaders() - - -if __name__ == "__main__": - args = parse_args() - config = Config(yaml_path=args.yaml_path) - setup(config.train_parameters.output_dir) - main(config) - cleanup() From 28d4edef37d0580f709fc0e00b2e303ddc6a96e6 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 13:52:04 -0400 Subject: [PATCH 41/89] Added instructions related to parse_benchmark --- profiling/README.md | 9 ++++----- profiling/parse_benchmark.py | 7 ++++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/profiling/README.md b/profiling/README.md index 143379b..886cfcd 100644 --- a/profiling/README.md +++ b/profiling/README.md @@ -13,9 +13,8 @@ $ python3 launch_benchmark.py # to accept and automatically invoke the comands. ``` -After the SLURM jobs complete, profiler output can be found under `/data/benchmark`. +After the SLURM jobs complete, profiler output can be found under `data/benchmark`. Invoke the following the to generate a Markdown summary of the results: - -## TODO - -Add script for automatically parsing profiler output files. \ No newline at end of file +```bash +$ python3 profiling/parse_benchmark.py --folder data/benchmark +``` diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index 149e80d..e62daae 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -76,9 +76,8 @@ print(throughput_table) table_output_lines: List[str] = [] -with open( - os.path.join(benchmark_artifact_folder, "table.md"), "w" -) as table_output_file: +markdown_output_path = os.path.join(benchmark_artifact_folder, "table.md") +with open(markdown_output_path, "w") as table_output_file: table_output_lines.append(throughput_table.to_markdown()) model_names = sorted(list(profiler_tables.keys())) @@ -93,3 +92,5 @@ table_output_lines.append("```\n{}\n```".format(profiler_table_str)) table_output_file.write("\n".join(table_output_lines)) + +print("\nWriting summary to {}".format(markdown_output_path)) From a863ed297d1ed823220fe8ff4ab4db50bf862ced Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Fri, 15 Mar 2024 16:28:24 -0400 Subject: [PATCH 42/89] parse_benchmark: implemented aggregation across repeated metrics. --- profiling/parse_benchmark.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index e62daae..53edd23 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -8,10 +8,30 @@ import os import json import glob -from typing import List +from typing import List, Dict, TypeVar import pandas +V = TypeVar("V", Dict, str, int) + + +def _reduce_metric(new_value: V, previous_value: V) -> V: + """ + Recursively reduce values. + + + """ + if isinstance(new_value, (float, int)): + return new_value + previous_value + elif isinstance(new_value, dict) and isinstance(previous_value, dict): + for k in previous_value.keys(): + if k in new_value.keys(): + previous_value[k] = _reduce_metric(new_value[k], previous_value[k]) + + return previous_value + else: + return new_value + parser = argparse.ArgumentParser() parser.add_argument("--folder", default="data/benchmark/") @@ -37,13 +57,16 @@ for raw_benchmark in raw_benchmarks: benchmark_output = {} - # Need to implement alternative reducing method - # string: most recent - # number: summation for line in raw_benchmark: name = line["name"] value = line["value"] - benchmark_output[name] = value + previous_value = benchmark_output.get(name) + if previous_value is not None: + new_value = _reduce_metric(value, previous_value) + else: + new_value = value + + benchmark_output[name] = new_value model_name = benchmark_output.get("model_name") if model_name is None: From eb3721afb90868ee695110ac9ce2599f90b9c51e Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 9 Apr 2024 15:06:14 -0400 Subject: [PATCH 43/89] Implemented non-LoRA profiling and benchmarking. --- configs/config.yaml | 4 +- examples/launch.sh | 2 +- examples/launch_lora.sh | 2 +- examples/launch_lora_one_gpu.sh | 2 +- profiling/README.md | 2 +- profiling/benchmark.py | 8 ++- profiling/configs/benchmark.yaml | 60 ++++++++++++++++++++ profiling/launch_benchmark.py | 24 ++++++-- profiling/launch_lora_benchmark.sh | 28 --------- profiling/parse_benchmark.py | 91 +++++++++++++++++++++++++----- vectorlm/dataset.py | 5 +- vectorlm/trainer.py | 6 +- vectorlm/utils/convert_to_hf.py | 4 +- vectorlm/utils/data_utils.py | 22 +++++--- vectorlm/utils/misc_utils.py | 3 +- vectorlm/utils/model_utils.py | 31 ++++++++-- 16 files changed, 222 insertions(+), 72 deletions(-) create mode 100644 profiling/configs/benchmark.yaml delete mode 100644 profiling/launch_lora_benchmark.sh diff --git a/configs/config.yaml b/configs/config.yaml index b7bc77f..a2a03e7 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -20,7 +20,8 @@ train_parameters: use_flash_attention: True low_cpu_mem_usage: True - # # Uncomment below to enable LoRA + # LoRA config: uncomment the block below to enable LoRA + # lora_peft_config: # task_type: CAUSAL_LM # inference_mode: False @@ -28,6 +29,7 @@ train_parameters: # lora_alpha: 32 # lora_dropout: 0.1 + # Gradient norm clipping max_grad_norm: 1 gradient_accumulation_steps: 4 diff --git a/examples/launch.sh b/examples/launch.sh index 92f0ad1..8dac6c1 100644 --- a/examples/launch.sh +++ b/examples/launch.sh @@ -23,4 +23,4 @@ export LOGLEVEL=INFO export PYTHONFAULTHANDLER=1 # export CUDA_LAUNCH_BLOCKING=0 -torchrun --nnodes=1 --nproc-per-node=${SLURM_STEP_GPUS} llama_example.py --yaml_path ../configs/config.yaml +torchrun --nnodes=1 --nproc-per-node=${SLURM_GPUS_ON_NODE} llama_example.py --yaml_path ../configs/config.yaml diff --git a/examples/launch_lora.sh b/examples/launch_lora.sh index 76f68e1..250daef 100644 --- a/examples/launch_lora.sh +++ b/examples/launch_lora.sh @@ -23,4 +23,4 @@ export LOGLEVEL=INFO export PYTHONFAULTHANDLER=1 # export CUDA_LAUNCH_BLOCKING=0 -torchrun --nnodes=1 --nproc-per-node=2 example_lora.py --yaml_path configs/config-lora.yaml +torchrun --nnodes=1 --nproc-per-node=${SLURM_GPUS_ON_NODE} example_lora.py --yaml_path configs/config-lora.yaml diff --git a/examples/launch_lora_one_gpu.sh b/examples/launch_lora_one_gpu.sh index 4390781..b030854 100644 --- a/examples/launch_lora_one_gpu.sh +++ b/examples/launch_lora_one_gpu.sh @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH --job-name=llama7b-2-lora #SBATCH --nodes=1 -#SBATCH --mem=0 +#SBATCH --mem=32GB #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-gpu=6 #SBATCH --gres=gpu:1 diff --git a/profiling/README.md b/profiling/README.md index 886cfcd..47d87b7 100644 --- a/profiling/README.md +++ b/profiling/README.md @@ -10,7 +10,7 @@ $ python3 launch_benchmark.py # The launcher script will print a list of # SLURM commands it plans to run. Press ENTER -# to accept and automatically invoke the comands. +# to accept and automatically invoke the commands. ``` After the SLURM jobs complete, profiler output can be found under `data/benchmark`. Invoke the following the to generate a Markdown summary of the results: diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 435b00b..51f6e71 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -15,7 +15,6 @@ from torch.profiler import ProfilerActivity from tqdm import tqdm from transformers import set_seed -from peft.utils.other import fsdp_auto_wrap_policy from vectorlm.dataset import Dataset from vectorlm.trainer import Trainer @@ -27,6 +26,7 @@ shard_model, get_submodule_by_pattern, get_lora_model_from_base_model, + get_half_precision_model, ) from vectorlm.utils.optimizer_utils import get_custom_scheduler from vectorlm.utils.save_utils import save_consolidated_model @@ -239,8 +239,13 @@ def main(config: Config, model_name: str) -> None: ) if lora_peft_config is not None: print("Enabling LoRA Wrapper.") + write_metrics("peft_method", "lora") model = get_lora_model_from_base_model(model, lora_peft_config) + else: + write_metrics("peft_method", "full_rank") + + model = get_half_precision_model(model) decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") if decoder_layer_module is None: @@ -332,6 +337,7 @@ def main(config: Config, model_name: str) -> None: batch = next(train_dl_iterator) trainer.step(batch, epoch) profile_handle.step() + write_metrics("torch.cuda.utilization", torch.cuda.utilization()) if epoch == training_args.epochs - 1: with track_time("save_final"): diff --git a/profiling/configs/benchmark.yaml b/profiling/configs/benchmark.yaml new file mode 100644 index 0000000..76c3b8d --- /dev/null +++ b/profiling/configs/benchmark.yaml @@ -0,0 +1,60 @@ +enable_wandb_logging: True + +wandb_config: + project: vector-lm-verify + name: benchmark-lora + +train_parameters: + output_dir: /dev/shm/lora-benchmark + max_seq_len: 128 + epochs: 1 + seed: 11 + + # Sharding strategy + sharding_strategy: FULL_SHARD + + # Memory + use_mp: True + use_activation_checkpointing: True + # use_flash_attention is automatically enabled + # for CUDA capability > 8.0 + low_cpu_mem_usage: True + + # Gradient norm clipping + max_grad_norm: 1 + gradient_accumulation_steps: 4 + + # Optimizer + optimizer: + lr: 2.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.95] + eps: 1.0e-5 + + # Scheduler + lr_scheduler_type: cosine + warmup_ratio: 0.05 + + # Checkpointing + checkpointing_enabled: True + logging_steps: 500 + save_frequency: 0.25 + +dataset: + ignore_index: -100 + eval_bs: 8 + train_bs: 8 + train_ds: /dev/null + eval_ds: /dev/null + +dataset_preprocess: + ignore_index: -100 + dataset_format: hf + data_field: question + packing_type: partial + add_bos_eos_tokens: True + from_disk: True + load_path: data/raw/gsm8k + split: train + save_path: data/processed/gsm8k-question/train + truncate: False diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index e96e4bf..34a1ddb 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -12,41 +12,52 @@ from tqdm.auto import tqdm parser = argparse.ArgumentParser() -parser.add_argument("--qos", required=False) +parser.add_argument("--qos", required=False, default="scavenger") +parser.add_argument("--partitions", required=False, default="t4v2,a40") parser.add_argument("--max_num_jobs", required=False) launcher_args = parser.parse_args() qos_selected = launcher_args.qos max_num_jobs = launcher_args.max_num_jobs +partitions = launcher_args.partitions.split(",") model_list = [ "/model-weights/" + model_name for model_name in [ "opt-350m", + "gemma-2b", "Llama-2-7b-hf", "Llama-2-13b-hf", + "Mistral-7B-v0.1", "Mixtral-8x7B-Instruct-v0.1", ] ] +config_list = [ + "profiling/configs/lora-benchmark.yaml", + "profiling/configs/benchmark.yaml", +] + slurm_flags_options = { "nodes": [1], - "mem": [0], + "mem-per-gpu": ["16GB"], "ntasks-per-node": [1], "cpus-per-gpu": [3], "gres": ["gpu:{}".format(n) for n in [1, 2, 4, 8]], - "partition": ["t4v2", "a40", "a100"], + "partition": partitions, } -slurm_flags_extra = {"time": "00:30:00", "qos": qos_selected} +num_repeats = 2 +slurm_flags_extra = {"time": "01:00:00", "qos": qos_selected} -slurm_pos_args_options = [["profiling/launch_lora_benchmark.sh"], model_list] +slurm_pos_args_options = [["profiling/launch_benchmark.sh"], config_list, model_list] timestamp = int(time.time()) args_list: List[List[str]] = [] -for index, (flag_values, pos_args_option) in enumerate( +for index, (flag_values, pos_args_option, _) in enumerate( itertools.product( itertools.product(*(slurm_flags_options.values())), itertools.product(*slurm_pos_args_options), + range(num_repeats) ) ): args: List[str] = ["sbatch"] @@ -77,3 +88,4 @@ makedirs("data/output", exist_ok=True) for args in tqdm(args_list, ncols=75): subprocess.run(args) + diff --git a/profiling/launch_lora_benchmark.sh b/profiling/launch_lora_benchmark.sh deleted file mode 100644 index 9a57050..0000000 --- a/profiling/launch_lora_benchmark.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand. We need to disable usage using this flag. -export NCCL_DEBUG=WARN -export NCCL_DEBUG_SUBSYS=WARN - -# export TORCH_DISTRIBUTED_DEBUG=DETAIL # Uncomment these flags for debugging communication -# export TORCH_CPP_LOG_LEVEL=INFO -export LOGLEVEL=INFO -export PYTHONFAULTHANDLER=1 -# export CUDA_LAUNCH_BLOCKING=0 - -source ~/vectorlm/env/bin/activate -export PYTHONPATH=$PYTHONPATH:`pwd` - -nvidia-smi -export num_gpus=`nvidia-smi -L | wc -l` -echo num_gpus: ${num_gpus} -echo model: $1 - -torchrun \ ---nnodes=1 \ ---nproc-per-node=${num_gpus} profiling/benchmark.py \ ---yaml_path profiling/configs/lora-benchmark.yaml \ ---model_name $1 - -# # clean up benchmarking artifacts as ops have requested -rm -rf /dev/shm/lora-benchmark \ No newline at end of file diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index 53edd23..77303bb 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -8,21 +8,49 @@ import os import json import glob -from typing import List, Dict, TypeVar +from typing import List, Dict, TypeVar, Tuple, List, Union +import numpy as np +from dataclasses import dataclass import pandas -V = TypeVar("V", Dict, str, int) +Numbers = Union[int, float] +V = TypeVar("V") +Aggregator = TypeVar("Aggregator") +@dataclass +class RunningAverage: + running_count: int = 0 + running_sum: Union[Numbers, np.ndarray] = None -def _reduce_metric(new_value: V, previous_value: V) -> V: - """ - Recursively reduce values. + def add(self, observation): + self.running_count += 1 + if self.running_sum is None: + self.running_sum = observation + else: + self.running_sum += observation + + def get_average(self) -> Union[Numbers, List[Numbers]]: + if self.running_count == 0: + return None + + average = self.running_sum / self.running_count + if hasattr(average, "tolist"): + return average.tolist() + + return average +def _reduce_metric(new_value, previous_value): + """ + Recursively reduce values. """ if isinstance(new_value, (float, int)): - return new_value + previous_value + if not isinstance(previous_value, list): + previous_value = [previous_value] + + return [*previous_value, new_value] + elif isinstance(new_value, dict) and isinstance(previous_value, dict): for k in previous_value.keys(): if k in new_value.keys(): @@ -33,6 +61,26 @@ def _reduce_metric(new_value: V, previous_value: V) -> V: return new_value +def get_quantiles(values: List[Numbers]) -> np.ndarray: + """ + Given a list of numeraical values, + return (min, 25%, 50%, 75%, and max). + + Params: + values: list of numerical values, must be non-empty. + + Returns: + np.ndarray. + """ + output_list = [ + np.min(values), + *[np.percentile(values, q) for q in [0.25, 0.5, 0.75]], + np.max(values), + ] + + return np.asarray(output_list) + + parser = argparse.ArgumentParser() parser.add_argument("--folder", default="data/benchmark/") args = parser.parse_args() @@ -40,6 +88,7 @@ def _reduce_metric(new_value: V, previous_value: V) -> V: # Load all benchmark result jsonl files benchmark_jsonl_list = glob.glob("*.jsonl", root_dir=benchmark_artifact_folder) +print(benchmark_jsonl_list) raw_benchmarks = [] for jsonl_filename in benchmark_jsonl_list: jsonl_path = os.path.join(benchmark_artifact_folder, jsonl_filename) @@ -52,7 +101,8 @@ def _reduce_metric(new_value: V, previous_value: V) -> V: raw_benchmarks.append(benchmark_content) # (model_name, device) -aggregated_output = defaultdict(dict) +benchmarked_combinations = set() +aggregated_output: Dict[Tuple[str, str], RunningAverage] = defaultdict(lambda: RunningAverage()) profiler_tables = defaultdict(dict) for raw_benchmark in raw_benchmarks: benchmark_output = {} @@ -75,25 +125,38 @@ def _reduce_metric(new_value: V, previous_value: V) -> V: model_name = model_name.split("/")[-1] source_filename = benchmark_output["_source"] + peft_method = benchmark_output.get("peft_method") + if peft_method is None: + continue + device_info = benchmark_output["device_info"] device_name = device_info["device_name"] world_size = device_info["world_size"] - device_description = "{} x{}".format(device_name, world_size) + device_description = "{} x{} {}".format(device_name, world_size, peft_method) train_step = benchmark_output.get("train_step") if train_step is not None: - train_throughput = ( - world_size * train_step["num_tokens"] / train_step["time_elapsed"] + num_tokens = np.asarray(train_step["num_tokens"]) + time_elapsed = np.asarray(train_step["time_elapsed"]) + train_throughput = get_quantiles( + world_size * num_tokens / time_elapsed ) - else: - train_throughput = None + aggregated_output[(model_name, device_description)].add(train_throughput) - aggregated_output[model_name][device_description] = train_throughput + benchmarked_combinations.add((model_name, device_description)) profiler_table_str = benchmark_output.get("profiler_table") if profiler_table_str is not None: profiler_tables[model_name][device_description] = profiler_table_str -throughput_table = pandas.DataFrame(aggregated_output).T + +aggregated_output_nested = defaultdict(dict) +for combination in benchmarked_combinations: + model_name, device_description = combination + throughput = aggregated_output[combination].get_average() + aggregated_output_nested[model_name][device_description] = throughput + + +throughput_table = pandas.DataFrame(aggregated_output_nested) throughput_table.sort_index(axis="columns", inplace=True) throughput_table.sort_index(axis="index", inplace=True) print(throughput_table) diff --git a/vectorlm/dataset.py b/vectorlm/dataset.py index b383803..ac40f34 100644 --- a/vectorlm/dataset.py +++ b/vectorlm/dataset.py @@ -74,8 +74,9 @@ def set_processed_ids(self, ids: list[int]) -> None: def load_datasets(self) -> None: """Load datasets into memory.""" - dirs_passed = self.config.get("train_ds", "") and \ - self.config.get("eval_ds", "") + dirs_passed = self.config.get("train_ds", "") and self.config.get( + "eval_ds", "" + ) if not dirs_passed: msg = "`train_ds` and `eval_ds` are missing from config." diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 9c024fe..bdd6c73 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -273,7 +273,6 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: ids = batch.pop("id").to(torch.cuda.current_device()) batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) batch["labels"] = batch["labels"].type(torch.LongTensor) - batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} self.dataset.update_processed_ids(ids) if (self.tr_step + 1) % self.gas != self.gas - 1: @@ -336,7 +335,10 @@ def eval_step(self, epoch: int) -> float: batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) num_tokens = len(batch["input_ids"].flatten()) batch["labels"] = batch["labels"].type(torch.LongTensor) - batch = {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} + batch = { + k: v.to(torch.cuda.current_device()) + for k, v in batch.items() + } with self.timer_handle("eval_step", {"num_tokens": num_tokens}): out = self.model(**batch) diff --git a/vectorlm/utils/convert_to_hf.py b/vectorlm/utils/convert_to_hf.py index c4f2405..4c62811 100644 --- a/vectorlm/utils/convert_to_hf.py +++ b/vectorlm/utils/convert_to_hf.py @@ -20,6 +20,7 @@ def parse_args() -> Namespace: parser.add_argument("--config_path", default="configs/config.yaml") return parser.parse_args() + def converter(config: Config) -> None: """Define main converting function. @@ -42,13 +43,14 @@ def converter(config: Config) -> None: config.model, True, False, - 2048, # doesn't matter so hard-coded. + 2048, # doesn't matter so hard-coded. ) model.load_state_dict(state_dict) model.save_pretrained( os.path.join(config.train_parameters.output_dir, "hf-model"), ) + if __name__ == "__main__": args = parse_args() config = Config(args.config_path) diff --git a/vectorlm/utils/data_utils.py b/vectorlm/utils/data_utils.py index ddf021f..f3cc839 100644 --- a/vectorlm/utils/data_utils.py +++ b/vectorlm/utils/data_utils.py @@ -38,7 +38,7 @@ def load_yaml(self) -> None: with open(self.yaml_path) as in_path: _config = yaml.safe_load(in_path) - for k,v in _config.items(): + for k, v in _config.items(): self.__setattr__(k, v) if self.to_box: self._to_box() @@ -107,10 +107,14 @@ def __call__( """ batch = {} keys = ["input_ids", "labels"] - input_ids, labels = tuple([ - torch.tensor( - instance[key][0:self.max_seq_len], - ) for instance in instances] for key in keys + input_ids, labels = tuple( + [ + torch.tensor( + instance[key][0 : self.max_seq_len], + ) + for instance in instances + ] + for key in keys ) batch["id"] = torch.tensor( [instance["id"] for instance in instances], @@ -121,10 +125,14 @@ def __call__( labels = self._reverse_tensor(labels) input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, batch_first=True, padding_value=self.pad_token_id, + input_ids, + batch_first=True, + padding_value=self.pad_token_id, ) labels = torch.nn.utils.rnn.pad_sequence( - labels, batch_first=True, padding_value=self.ignore_index, + labels, + batch_first=True, + padding_value=self.ignore_index, ) if self.padding_side == "left": diff --git a/vectorlm/utils/misc_utils.py b/vectorlm/utils/misc_utils.py index 6e8cc1a..081deb2 100644 --- a/vectorlm/utils/misc_utils.py +++ b/vectorlm/utils/misc_utils.py @@ -12,7 +12,8 @@ def setup(final_model_dir: str) -> None: """Initialize the process group and create directories.""" os.makedirs( - os.path.join(final_model_dir, "final-model"), exist_ok=True, + os.path.join(final_model_dir, "final-model"), + exist_ok=True, ) dist.init_process_group("nccl") diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 2bbf90f..b4a09e9 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -49,6 +49,25 @@ def _is_bfloat_available() -> bool: return False +def get_half_precision_model(model: nn.Module) -> nn.Module: + """ + Cast model to appropriate half-precision format + depending on GPU hardware support. + + Args: + ---- + + model: nn.Module to cast. + + Returns: + ------- + + nn.Module + """ + model = model.bfloat16() + return model + + def get_lora_model_from_base_model( base_model: nn.Module, peft_config_dict: Dict ) -> PeftModel: @@ -67,17 +86,19 @@ def get_lora_model_from_base_model( # See github.com/pytorch/pytorch/pull/102212 base_model.load_state_dict(base_model.state_dict(), assign=True) lora_model = get_peft_model(base_model, lora_config) - - if _is_bfloat_available(): - lora_model = lora_model.bfloat16() - else: - lora_model = lora_model.half() + lora_model = get_half_precision_model(lora_model) assert isinstance(lora_model, PeftModel) lora_model.print_trainable_parameters() return lora_model +def _assert_parameter_shapes(model: nn.Module): + for name, parameter in model.named_parameters(): + print(name, parameter.dtype) + assert parameter.dtype is torch.float16 + + def load_peft_model_and_tokenizer( path: str, use_mp: bool, From 37f5dec4c24308c29f879b7c29be05fc55a6aa9d Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 11 Apr 2024 15:23:57 -0400 Subject: [PATCH 44/89] Various static typechecking and formatting fixes. --- examples/__init__.py | 0 examples/llama_example.py | 9 +- profiling/__init__.py | 0 profiling/benchmark.py | 248 +++++++++++++++++++----------- profiling/launch_benchmark.py | 37 +++-- profiling/parse_benchmark.py | 174 +++++++++++++-------- pyproject.toml | 8 + vectorlm/dataset.py | 2 +- vectorlm/tests/test_dataset.py | 36 +++-- vectorlm/tests/test_modelling.py | 175 +++++++++++---------- vectorlm/trainer.py | 31 ++-- vectorlm/utils/convert_to_hf.py | 3 + vectorlm/utils/misc_utils.py | 2 +- vectorlm/utils/model_utils.py | 142 ++++------------- vectorlm/utils/optimizer_utils.py | 10 +- vectorlm/utils/save_utils.py | 4 +- 16 files changed, 492 insertions(+), 389 deletions(-) create mode 100644 examples/__init__.py create mode 100644 profiling/__init__.py diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/llama_example.py b/examples/llama_example.py index 3d284ec..7d684fc 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import math import os @@ -9,7 +11,6 @@ from torch.optim import AdamW from tqdm import tqdm from transformers import set_seed -from transformers.models.llama.modeling_llama import LlamaDecoderLayer from vectorlm.dataset import Dataset from vectorlm.trainer import Trainer @@ -74,13 +75,15 @@ def main(config: Config) -> None: training_args.low_cpu_mem_usage, ) - lora_peft_config = getattr(config.train_parameters, "lora_peft_config", None) + lora_peft_config = getattr( + config.train_parameters, "lora_peft_config", None, + ) if lora_peft_config is not None: model = get_lora_model_from_base_model(model, lora_peft_config) decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") model = shard_model( - model, + model.bfloat16(), decoder_layer_module, training_args.use_mp, training_args.use_activation_checkpointing, diff --git a/profiling/__init__.py b/profiling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 51f6e71..e169e53 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -1,3 +1,5 @@ +from __future__ import annotations + # Renamed from examples/llama_example.py import argparse import contextlib @@ -7,30 +9,37 @@ import sys import time from argparse import Namespace -from typing import Any, Dict, Optional +from typing import Any, Generator import torch import torch.distributed as dist from torch.optim import AdamW from torch.profiler import ProfilerActivity from tqdm import tqdm -from transformers import set_seed +from transformers import PreTrainedTokenizer, set_seed from vectorlm.dataset import Dataset from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup from vectorlm.utils.model_utils import ( + get_half_precision_model, + get_lora_model_from_base_model, + get_submodule_by_pattern, hook_activation_checkpointing, load_model_and_tokenizer, shard_model, - get_submodule_by_pattern, - get_lora_model_from_base_model, - get_half_precision_model, ) from vectorlm.utils.optimizer_utils import get_custom_scheduler from vectorlm.utils.save_utils import save_consolidated_model +JSONSerializable = str | dict[str, Any] | list[str] | float | None +_MIN_FLASH_ATTENTION_CUDA_CAPABILITY = 8 + +# Cap value ot tokenizer.model_max_length to this value, +# unless overridden when instantiating the benchmarking dataset. +_MAX_SEQ_LENGTH = 65536 + def parse_args() -> Namespace: """Parse command-line arguments. @@ -38,6 +47,7 @@ def parse_args() -> Namespace: Returns ------- The parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument( @@ -49,6 +59,14 @@ def parse_args() -> Namespace: "--model_name", required=True, ) + parser.add_argument( + "--num_train_examples", + default=10000, + ) + parser.add_argument( + "--num_eval_examples", + default=1000, + ) return parser.parse_args() @@ -56,14 +74,17 @@ def parse_args() -> Namespace: launch_time = time.time() os.makedirs("data/benchmark", exist_ok=True) os.makedirs("data/trace", exist_ok=True) -output_path = "data/benchmark/{}.jsonl".format(launch_time) -profiler_output_path = "data/trace/{}.json".format(launch_time) +output_path = f"data/benchmark/{launch_time}.jsonl" +profiler_output_path = f"data/trace/{launch_time}.json" -def write_metrics(metric_name: str, value: Optional[Any] = None) -> None: - """ - Write metric and time elapsed to output file. - Write to disk only if process rank is 0. +def write_metrics( + metric_name: str, + value: JSONSerializable = None, +) -> None: + """Write metric and time elapsed to output file. + + This function writes to disk only if process rank is 0. Params: metric_name: string indicating type of metric @@ -84,61 +105,74 @@ def write_metrics(metric_name: str, value: Optional[Any] = None) -> None: @contextlib.contextmanager -def track_time(task_name: str, extra_info: Dict[str, Any] = {}): +def track_time( + task_name: str, + extra_info: dict[str, Any] | None = None, +) -> Generator[None, None, None]: + """Context manager for recording time spent in a code block. + + Params + ------ + task_name: str + extra_info: Optional, JSON-serializable dictionary + to include in log output. + + """ start_time = time.time() try: yield finally: time_elapsed = time.time() - start_time - write_metrics(task_name, {"time_elapsed": time_elapsed, **extra_info}) + metric_value = {"time_elapsed": time_elapsed} + if extra_info is not None: + metric_value = {**metric_value, **extra_info} + write_metrics(task_name, metric_value) -def get_device_info() -> Dict[str, str | int]: - """ - Get CUDA info as a dict. - Returns: +def get_device_info() -> dict[str, str | int]: + """Get CUDA info as a dict. + + Returns + ------- Dict including device_name and world size + """ - return dict( - device_name=torch.cuda.get_device_name(), - local_rank=int(os.environ["LOCAL_RANK"]), - rank=int(os.environ["RANK"]), - world_size=int(os.environ["WORLD_SIZE"]), - ) + return { + "device_name": torch.cuda.get_device_name(), + "local_rank": int(os.environ["LOCAL_RANK"]), + "rank": int(os.environ["RANK"]), + "world_size": int(os.environ["WORLD_SIZE"]), + } def get_is_flash_attention_supported() -> bool: - """ - Returns: - Whether Flash Attention is supported based on - the given CUDA device capability. - """ + """Determine whether flash attention is available.""" version_major, _ = torch.cuda.get_device_capability() - return version_major >= 8 + return version_major >= _MIN_FLASH_ATTENTION_CUDA_CAPABILITY -def get_slurm_env() -> Dict[str, str]: - """ - Returns a dictionary of all env var starting with "SLURM_". - """ - output = { - key: value for key, value in os.environ.items() if key.startswith("SLURM_") +def get_slurm_env() -> dict[str, str]: + """Return a dictionary of all env var starting with "SLURM_".""" + return { + key: value + for key, value in os.environ.items() + if key.startswith("SLURM_") } - return output def parse_profiler_output( profiler_output: torch.autograd.profiler.profile, -) -> Dict[str, Dict[str, str | float | int]]: - """ - Parse profiler_output to obtain dictionary of metrics. +) -> dict[str, dict[str, str | float | int]]: + """Parse profiler_output to obtain dictionary of metrics. - Returns: + Returns + ------- Dictionary mapping event name to dictionary of metrics. + """ key_average_event_list = profiler_output.key_averages() - output: Dict[str, Dict[str, str | float | int]] = {} + output: dict[str, dict[str, str | float | int]] = {} for evt in key_average_event_list: trace_name = getattr(evt, "trace_name", None) if trace_name is None: @@ -156,14 +190,17 @@ def parse_profiler_output( return output -def handle_profiler_trace(profiler_output: torch.autograd.profiler.profile): - """ - Log torch profile to disk. +def _handle_profiler_trace( + profiler_output: torch.autograd.profiler.profile, +) -> None: + """Log torch profile to disk. + This function is to be invoked as a callback for on_track_ready. Args: - ----- - profile: from Torch profiler. + ---- + profiler_output: from Torch profiler. + """ print(profiler_output) key_average_event_list = profiler_output.key_averages() @@ -176,31 +213,75 @@ def handle_profiler_trace(profiler_output: torch.autograd.profiler.profile): class BenchmarkingDataset(Dataset): + """In-memory dataset for benchmarking.""" + + def __init__( + self, + config: Config, + num_train_examples: int, + num_eval_examples: int, + tokenizer: PreTrainedTokenizer, + max_length: int | None = None, + ) -> None: + """Initialize in-memory dataset for benchmarking. + + Refer to vectorlm.dataset for details regarding config + and tokenizer. + + Params: + ------ + config: dataset config. Forwarded to vectorlm.dataset.Dataset. + num_train_examples: length of train split. + num_eval_examples: length of eval split. + tokenizer: HuggingFace tokenizer. + max_length: optional. If not specified, + fall back to tokenizer.model_max_length. + """ + self.num_train_examples = num_train_examples + self.num_eval_examples = num_eval_examples + + if max_length is not None: + self.max_length = max_length + else: + self.max_length = min(tokenizer.model_max_length, _MAX_SEQ_LENGTH) + + super().__init__(config, tokenizer) + def load_datasets(self) -> None: """Load datasets into memory.""" - self.train_ds = [ - { - "id": row_id, - "input_ids": torch.zeros(1024), - "labels": torch.zeros(1024), - "attention_mask": torch.ones(1024), - } - for row_id in range(8192) - ] - self.eval_ds = self.train_ds[: len(self.train_ds) // 10] + self.train_ds, self.eval_ds = ( + [ + { + "id": row_id, + "input_ids": torch.zeros(self.max_length), + "labels": torch.zeros(self.max_length), + "attention_mask": torch.ones(self.max_length), + } + for row_id in range(length) + ] + for length in (self.num_train_examples, self.num_eval_examples) + ) + self.original_length = math.ceil(len(self.train_ds) / self.train_bs) -def main(config: Config, model_name: str) -> None: - """Define the main calling function.""" - print("Writing metrics to {}".format(output_path)) - write_metrics("model_name", model_name) +if __name__ == "__main__": + args = parse_args() + config = Config(yaml_path=args.yaml_path) + setup(config.train_parameters.output_dir) + + print(f"Writing metrics to {output_path}") + write_metrics("model_name", args.model_name) write_metrics("config", {**config.__dict__}) write_metrics("device_info", get_device_info()) write_metrics("slurm_info", get_slurm_env()) profiler_schedule = torch.profiler.schedule( - skip_first=10, wait=5, warmup=1, active=3, repeat=2 + skip_first=10, + wait=5, + warmup=1, + active=3, + repeat=2, ) training_args = config.train_parameters @@ -225,12 +306,15 @@ def main(config: Config, model_name: str) -> None: dist.barrier() # load model and tokenizer - state_dict_path = getattr(config, "state_dict", None) - lora_peft_config = getattr(config.train_parameters, "lora_peft_config", None) + lora_peft_config = getattr( + config.train_parameters, + "lora_peft_config", + None, + ) with track_time("model_load"): model, tokenizer = load_model_and_tokenizer( - model_name, + args.model_name, training_args.use_mp, get_is_flash_attention_supported(), training_args.max_seq_len, @@ -249,8 +333,8 @@ def main(config: Config, model_name: str) -> None: decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") if decoder_layer_module is None: - track_time("decoder_layer_module_is_none") - raise ValueError("decoder_layer_module is None.") + msg = "decoder_layer_module is None." + raise ValueError(msg) with track_time("model_shard"): model = shard_model( @@ -262,14 +346,6 @@ def main(config: Config, model_name: str) -> None: local_rank, training_args.low_cpu_mem_usage, ) - per_device_parameter_count = sum(p.numel() for p in model.parameters()) - track_time( - "parameter_count", - { - "per_device": per_device_parameter_count, - "total": per_device_parameter_count * world_size, - }, - ) with track_time("set_activation_checkpointing"): if training_args.use_activation_checkpointing: @@ -279,6 +355,8 @@ def main(config: Config, model_name: str) -> None: with track_time("dataset_load"): dataset = BenchmarkingDataset( config=config.dataset, + num_train_examples=args.num_train_examples, + num_eval_examples=args.num_eval_examples, tokenizer=tokenizer, ) @@ -315,7 +393,6 @@ def main(config: Config, model_name: str) -> None: lr_scheduler, ) - # TODO: support restoring LoRA fine-tuning trainer.dataset.setup_dataloaders() checkpointed_epoch = 0 @@ -323,25 +400,30 @@ def main(config: Config, model_name: str) -> None: with torch.profiler.profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=profiler_schedule, - on_trace_ready=handle_profiler_trace, + on_trace_ready=_handle_profiler_trace, ) as profile_handle: for epoch in range(checkpointed_epoch, training_args.epochs): trainer.model.train() train_dl_iterator = iter(dataset.train_dataloader) for _ in tqdm( - # range(len(dataset.train_dataloader)), - range(7 * 13), + range(args.num_train_examples), disable=rank != 0, file=sys.__stdout__, ): batch = next(train_dl_iterator) trainer.step(batch, epoch) profile_handle.step() - write_metrics("torch.cuda.utilization", torch.cuda.utilization()) + write_metrics( + "torch.cuda.utilization", + torch.cuda.utilization(), + ) if epoch == training_args.epochs - 1: with track_time("save_final"): - hf_save_dir = os.path.join(training_args.output_dir, "final-model") + hf_save_dir = os.path.join( + training_args.output_dir, + "final-model", + ) else: with track_time("save_checkpoint"): hf_save_dir = os.path.join( @@ -357,10 +439,4 @@ def main(config: Config, model_name: str) -> None: dataset.reset_dataloaders() - -if __name__ == "__main__": - args = parse_args() - config = Config(yaml_path=args.yaml_path) - setup(config.train_parameters.output_dir) - main(config, args.model_name) cleanup() diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index 34a1ddb..9374350 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -1,13 +1,12 @@ -""" -Create SLURM jobs running the LoRA benchmark. -""" +"""Create SLURM jobs running the LoRA benchmark.""" + +from __future__ import annotations import argparse import itertools import subprocess import time from os import makedirs -from typing import List from tqdm.auto import tqdm @@ -42,38 +41,43 @@ "mem-per-gpu": ["16GB"], "ntasks-per-node": [1], "cpus-per-gpu": [3], - "gres": ["gpu:{}".format(n) for n in [1, 2, 4, 8]], + "gres": [f"gpu:{n}" for n in [1, 2, 4, 8]], "partition": partitions, } num_repeats = 2 slurm_flags_extra = {"time": "01:00:00", "qos": qos_selected} -slurm_pos_args_options = [["profiling/launch_benchmark.sh"], config_list, model_list] +slurm_pos_args_options = [ + ["profiling/launch_benchmark.sh"], + config_list, + model_list, +] timestamp = int(time.time()) -args_list: List[List[str]] = [] +args_list: list[list[str]] = [] for index, (flag_values, pos_args_option, _) in enumerate( itertools.product( itertools.product(*(slurm_flags_options.values())), itertools.product(*slurm_pos_args_options), - range(num_repeats) - ) + range(num_repeats), + ), ): - args: List[str] = ["sbatch"] + args: list[str] = ["sbatch"] + log_file_path = f"data/output/{timestamp}.{index}.out" extra_flags = { **slurm_flags_extra, - "output": "data/output/{}.{}.out".format(timestamp, index), - "error": "data/output/{}.{}.out".format(timestamp, index), - "job-name": "bench-{}-{}".format(timestamp, index), + "output": log_file_path, + "error": log_file_path, + "job-name": f"bench-{timestamp}", } keys = list(slurm_flags_options.keys()) + list(extra_flags.keys()) values = list(flag_values) + list(extra_flags.values()) for key, value in zip(keys, values): if value is not None: - arg = ("--{}".format(key), str(value)) + arg = (f"--{key}", str(value)) args.extend(arg) args.extend(pos_args_option) @@ -83,9 +87,8 @@ if (max_num_jobs is not None) and index + 1 >= int(max_num_jobs): break -input("\nPress ENTER to launch {} job(s)".format(len(args_list))) +input(f"\nPress ENTER to launch {len(args_list)} job(s)") makedirs("data/output", exist_ok=True) for args in tqdm(args_list, ncols=75): - subprocess.run(args) - + subprocess.run(args, check=False) diff --git a/profiling/parse_benchmark.py b/profiling/parse_benchmark.py index 77303bb..cd7b7bf 100644 --- a/profiling/parse_benchmark.py +++ b/profiling/parse_benchmark.py @@ -1,49 +1,82 @@ -""" -Parse benchmarking results -to generate metrics overview table. -""" +"""Parse benchmarking results to generate metrics overview table.""" + +from __future__ import annotations import argparse -from collections import defaultdict -import os -import json import glob -from typing import List, Dict, TypeVar, Tuple, List, Union -import numpy as np +import json +import os +from collections import defaultdict from dataclasses import dataclass +from typing import TypeVar, Union -import pandas +import numpy as np +import pandas as pd Numbers = Union[int, float] +NumericalTypes = Union[Numbers, np.ndarray] V = TypeVar("V") Aggregator = TypeVar("Aggregator") +Numerical = TypeVar("Numerical", bound=NumericalTypes) + @dataclass class RunningAverage: + """Abstraction for tracking numbers required to compute averages. + + Params: + running_count: number of observations added + + """ + running_count: int = 0 - running_sum: Union[Numbers, np.ndarray] = None + running_sum: NumericalTypes | None = None + + def add(self, observation: NumericalTypes) -> None: + """Add observation to accumulator. - def add(self, observation): + Params + ------ + observation: must be numerical and of same type + (number or np.ndarray) as running_sum. + + """ self.running_count += 1 if self.running_sum is None: self.running_sum = observation else: self.running_sum += observation - def get_average(self) -> Union[Numbers, List[Numbers]]: - if self.running_count == 0: + def get_average(self) -> NumericalTypes | None: + """Obtain average of this accumulator. + + Returns + ------- + NumericalTypes + same type (number or np.ndarray) as self.running_sum. + + """ + if (self.running_count == 0) or (self.running_sum is None): return None - - average = self.running_sum / self.running_count - if hasattr(average, "tolist"): - return average.tolist() - return average + return self.running_sum / self.running_count -def _reduce_metric(new_value, previous_value): - """ - Recursively reduce values. +def _reduce_metric( + new_value: NumericalTypes | str | dict, + previous_value: NumericalTypes | str | dict | list, +) -> NumericalTypes | str | dict | list: + """Recursively reduce values. + + Params + ------ + new_value: value to aggregate + previous_value: aggregator + + Returns + ------- + Same type as previous value. + """ if isinstance(new_value, (float, int)): if not isinstance(previous_value, list): @@ -51,26 +84,30 @@ def _reduce_metric(new_value, previous_value): return [*previous_value, new_value] - elif isinstance(new_value, dict) and isinstance(previous_value, dict): - for k in previous_value.keys(): - if k in new_value.keys(): - previous_value[k] = _reduce_metric(new_value[k], previous_value[k]) + if isinstance(new_value, dict) and isinstance(previous_value, dict): + for k in previous_value: + if k in new_value: + previous_value[k] = _reduce_metric( + new_value[k], + previous_value[k], + ) return previous_value - else: - return new_value + return new_value -def get_quantiles(values: List[Numbers]) -> np.ndarray: - """ - Given a list of numeraical values, - return (min, 25%, 50%, 75%, and max). - Params: +def get_quantiles(values: list[Numbers]) -> np.ndarray: + """Given a list of numerical values, return (min, 25%, 50%, 75%, and max). + + Params + ------ values: list of numerical values, must be non-empty. - Returns: + Returns + ------- np.ndarray. + """ output_list = [ np.min(values), @@ -88,11 +125,10 @@ def get_quantiles(values: List[Numbers]) -> np.ndarray: # Load all benchmark result jsonl files benchmark_jsonl_list = glob.glob("*.jsonl", root_dir=benchmark_artifact_folder) -print(benchmark_jsonl_list) raw_benchmarks = [] for jsonl_filename in benchmark_jsonl_list: jsonl_path = os.path.join(benchmark_artifact_folder, jsonl_filename) - with open(jsonl_path, "r") as jsonl_file: + with open(jsonl_path) as jsonl_file: benchmark_content = [ json.loads(line) for line in jsonl_file.read().splitlines() ] @@ -100,13 +136,20 @@ def get_quantiles(values: List[Numbers]) -> np.ndarray: raw_benchmarks.append(benchmark_content) -# (model_name, device) -benchmarked_combinations = set() -aggregated_output: Dict[Tuple[str, str], RunningAverage] = defaultdict(lambda: RunningAverage()) +# Set of tuples the form (model_name, device) +benchmarked_combinations: set[tuple[str, str]] = set() +aggregated_output: dict[tuple[str, str], RunningAverage] = defaultdict( + lambda: RunningAverage(), +) profiler_tables = defaultdict(dict) + +# Aggregate benchmark files to obtain average values +# for each model-device combination. for raw_benchmark in raw_benchmarks: benchmark_output = {} + # If an entry (e.g., train_step) is logged multiple times + # in the benchmark output, aggregate these values. for line in raw_benchmark: name = line["name"] value = line["value"] @@ -126,23 +169,34 @@ def get_quantiles(values: List[Numbers]) -> np.ndarray: source_filename = benchmark_output["_source"] peft_method = benchmark_output.get("peft_method") + if peft_method == "lora" and model_name == "gemma-2b": + print(source_filename) if peft_method is None: continue - + device_info = benchmark_output["device_info"] device_name = device_info["device_name"] - world_size = device_info["world_size"] - device_description = "{} x{} {}".format(device_name, world_size, peft_method) + if isinstance(device_info["world_size"], list): + world_size = device_info["world_size"][0] + else: + world_size = device_info["world_size"] + device_description = f"({peft_method}) {device_name} x{world_size}" + # Training throughput can be noisy. Report quantiles instead of avg, + # and discard instances with only one training step logged. train_step = benchmark_output.get("train_step") if train_step is not None: num_tokens = np.asarray(train_step["num_tokens"]) time_elapsed = np.asarray(train_step["time_elapsed"]) - train_throughput = get_quantiles( - world_size * num_tokens / time_elapsed - ) - aggregated_output[(model_name, device_description)].add(train_throughput) - + if num_tokens.flatten().shape[0] > 1: + train_throughput = get_quantiles( + world_size * num_tokens / time_elapsed, + ) + aggregated_output[(model_name, device_description)].add( + train_throughput[2], + ) + + # torch profiler output in tabular format benchmarked_combinations.add((model_name, device_description)) profiler_table_str = benchmark_output.get("profiler_table") if profiler_table_str is not None: @@ -154,29 +208,19 @@ def get_quantiles(values: List[Numbers]) -> np.ndarray: model_name, device_description = combination throughput = aggregated_output[combination].get_average() aggregated_output_nested[model_name][device_description] = throughput - -throughput_table = pandas.DataFrame(aggregated_output_nested) -throughput_table.sort_index(axis="columns", inplace=True) -throughput_table.sort_index(axis="index", inplace=True) + +throughput_table = ( + pd.DataFrame(aggregated_output_nested) + .sort_index(axis="columns") + .sort_index(axis="index") +) print(throughput_table) -table_output_lines: List[str] = [] +table_output_lines: list[str] = [] markdown_output_path = os.path.join(benchmark_artifact_folder, "table.md") with open(markdown_output_path, "w") as table_output_file: table_output_lines.append(throughput_table.to_markdown()) - - model_names = sorted(list(profiler_tables.keys())) - for model_name in model_names: - table_output_lines.append("\n## {}".format(model_name)) - profiler_table_dict = profiler_tables[model_name] - device_descriptions = sorted(list(profiler_table_dict.keys())) - - for device_description in device_descriptions: - profiler_table_str = profiler_table_dict[device_description] - table_output_lines.append("### {}".format(device_description)) - table_output_lines.append("```\n{}\n```".format(profiler_table_str)) - table_output_file.write("\n".join(table_output_lines)) -print("\nWriting summary to {}".format(markdown_output_path)) +print(f"\nWriting summary to {markdown_output_path}") diff --git a/pyproject.toml b/pyproject.toml index c9ca70d..f1261d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,3 +22,11 @@ lint.ignore = [ [tool.ruff.lint.per-file-ignores] # Ignore `F401` (import violations) in all `__init__.py` files. "__init__.py" = ["F401", "D104"] + +# PyTest fixtures imported from other modules are considered +# as "unused". Ignore these in unit test modules. +# pytest.mark.usefixtures decorator are not available for fixtures. +# Some ANN001 instances are ignored in-line. +"vectorlm/tests/*" = ["F401", "F811"] +"profiling/launch_benchmark.py" = ["S603"] + diff --git a/vectorlm/dataset.py b/vectorlm/dataset.py index ac40f34..c6f042f 100644 --- a/vectorlm/dataset.py +++ b/vectorlm/dataset.py @@ -75,7 +75,7 @@ def set_processed_ids(self, ids: list[int]) -> None: def load_datasets(self) -> None: """Load datasets into memory.""" dirs_passed = self.config.get("train_ds", "") and self.config.get( - "eval_ds", "" + "eval_ds", "", ) if not dirs_passed: diff --git a/vectorlm/tests/test_dataset.py b/vectorlm/tests/test_dataset.py index 7856487..455229c 100644 --- a/vectorlm/tests/test_dataset.py +++ b/vectorlm/tests/test_dataset.py @@ -1,11 +1,15 @@ -import pytest -from vectorlm.tests.test_modelling import setup_and_teardown_torch_process_group -from profiling.benchmark import BenchmarkingDataset +"""Unit tests for the in-memory benchmarking dataset.""" +import pytest from box import Box - from transformers import AutoTokenizer +from profiling.benchmark import BenchmarkingDataset +from vectorlm.tests.test_modelling import ( + _setup_and_teardown_torch_process_group, +) + +_BATCH_TOKEN_DIMENSIONALITY = 2 dataset_config = Box( { @@ -14,26 +18,38 @@ "train_bs": 8, "train_ds": "/dev/null", "eval_ds": "/dev/null", - } + }, ) @pytest.fixture() -def benchmark_dataset(setup_and_teardown_torch_process_group): +def benchmark_dataset( + _setup_and_teardown_torch_process_group, # noqa: ANN001 +) -> BenchmarkingDataset: + """Instantiate example in-memory benchmarking dataset.""" tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - return BenchmarkingDataset(dataset_config, tokenizer) # type: ignore + return BenchmarkingDataset( + config=dataset_config, + num_train_examples=10000, + num_eval_examples=1000, + tokenizer=tokenizer, + ) -def test_initialize_dataset(benchmark_dataset): +def test_initialize_dataset(benchmark_dataset: BenchmarkingDataset) -> None: + """Ensure that instantiating dataset does not throw an error message.""" print(benchmark_dataset) -def test_get_batch(benchmark_dataset): +def test_get_batch(benchmark_dataset: BenchmarkingDataset) -> None: + """Verify shape of dataset iterator output.""" benchmark_dataset.setup_dataloaders() dataset_iterator = iter(benchmark_dataset.train_dataloader) batch = next(dataset_iterator) for key in ["input_ids", "attention_mask"]: - assert len(batch[key].shape) == 2 # batch, tokens + assert ( + len(batch[key].shape) == _BATCH_TOKEN_DIMENSIONALITY + ) # batch, tokens print(batch) diff --git a/vectorlm/tests/test_modelling.py b/vectorlm/tests/test_modelling.py index c0d8403..a88482c 100644 --- a/vectorlm/tests/test_modelling.py +++ b/vectorlm/tests/test_modelling.py @@ -1,36 +1,34 @@ -""" -Test model loading, sharding, and forward/backward. -""" +"""Test model loading, sharding, and forward/backward.""" + +from __future__ import annotations -from collections import Counter, defaultdict import os import re +from collections import Counter, defaultdict +from typing import Any, Generator import pytest import torch import torch.distributed as dist from torch import nn -from torch.optim import AdamW -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, ) +from torch.optim import AdamW from transformers.models.opt.modeling_opt import OPTDecoderLayer from vectorlm.utils.model_utils import ( get_lora_model_from_base_model, + get_submodule_by_pattern, load_model_and_tokenizer, shard_model, - get_submodule_by_pattern, ) local_rank = int(os.environ.get("LOCAL_RANK", 0)) @pytest.fixture() -def setup_and_teardown_torch_process_group(): - # Setup +def _setup_and_teardown_torch_process_group() -> Generator[None, None, None]: dist.init_process_group( backend="nccl", init_method="tcp://localhost:25567", @@ -45,10 +43,8 @@ def setup_and_teardown_torch_process_group(): @pytest.fixture() -def lora_peft_config(): - """ - Example peft config_dict for LoRA. - """ +def lora_peft_config() -> dict[str, Any]: + """Populate example peft config_dict for LoRA.""" return { "task_type": "CAUSAL_LM", "inference_mode": False, @@ -59,43 +55,74 @@ def lora_peft_config(): @pytest.fixture() -def base_model(): +def base_model() -> torch.nn.Module: + """Instantiate example non-sharded non-peft transformer model.""" model, tokenizer = load_model_and_tokenizer( - "/model-weights/opt-350m", True, False, 1024, local_rank, True + "/model-weights/opt-350m", + True, + False, + 1024, + local_rank, + True, ) return model @pytest.fixture() -def lora_model(base_model, lora_peft_config): - lora_model = get_lora_model_from_base_model(base_model, lora_peft_config) - return lora_model +def lora_model( + base_model: torch.nn.Module, + lora_peft_config: dict[str, Any], +) -> torch.nn.Module: + """Obtain LoRA-wrapped base model.""" + return get_lora_model_from_base_model(base_model, lora_peft_config) @pytest.fixture() -def base_model_sharded(base_model, setup_and_teardown_torch_process_group): - model_sharded = shard_model( - base_model, OPTDecoderLayer, True, True, "FULL_SHARD", local_rank, True +def base_model_sharded( + base_model: torch.nn.Module, + _setup_and_teardown_torch_process_group, # noqa: ANN001 +) -> torch.nn.Module: + """Obtain FSDP-sharded base model.""" + return shard_model( + base_model, + OPTDecoderLayer, + True, + True, + "FULL_SHARD", + local_rank, + True, ) - return model_sharded @pytest.fixture() -def lora_model_sharded(lora_model, setup_and_teardown_torch_process_group): +def lora_model_sharded( + lora_model: torch.nn.Module, + _setup_and_teardown_torch_process_group, # noqa: ANN001 +) -> torch.nn.Module: + """Obtain FSDP-sharded LoRA model.""" model_sharded = shard_model( - lora_model, OPTDecoderLayer, True, True, "FULL_SHARD", local_rank, True + lora_model, + OPTDecoderLayer, + True, + True, + "FULL_SHARD", + local_rank, + True, ) return FSDP(model_sharded, device_id=torch.cuda.current_device()) @pytest.fixture() -def optimizer_lora_sharded(lora_model_sharded): - optimizer = AdamW(lora_model_sharded.parameters()) - return optimizer +def optimizer_lora_sharded( + lora_model_sharded: torch.nn.Module, +) -> torch.optim.AdamW: + """Instantiate optimizer for sharded LoRA model.""" + return AdamW(lora_model_sharded.parameters()) @pytest.fixture() -def batch(): +def batch() -> dict[str, torch.Tensor]: + """Populate example batch for testing.""" batch = { "input_ids": torch.zeros((1, 12)), "labels": torch.zeros((1, 12)), @@ -103,20 +130,16 @@ def batch(): } batch = {k: v.type(torch.LongTensor) for k, v in batch.items()} - batch = {k: v.to(torch.device(0)) for k, v in batch.items()} - - return batch + return {k: v.to(torch.device(0)) for k, v in batch.items()} -def test_load_base_model(base_model): +def test_load_base_model(base_model: torch.nn.Module) -> None: + """Ensure no error is encountered when instantiating base model fixture.""" print(base_model) -def test_match_submodule_by_pattern(base_model, lora_model): - """ - Test selecting DecoderLayer class from container. - """ - +def test_match_submodule_by_pattern(base_model: torch.nn.Module) -> None: + """Test selecting DecoderLayer class from container.""" submodule = get_submodule_by_pattern(base_model, r"DecoderLayer$") assert submodule == OPTDecoderLayer @@ -124,43 +147,36 @@ def test_match_submodule_by_pattern(base_model, lora_model): assert submodule == OPTDecoderLayer -def test_partition_base_model( - base_model_sharded, setup_and_teardown_torch_process_group -): - """ - Test partitioning base model (no lora/peft). - """ +@pytest.mark.usefixtures("_setup_and_teardown_torch_process_group") +def test_partition_base_model(base_model_sharded: torch.nn.Module) -> None: + """Test partitioning base model (no lora/peft).""" output_text = [] for parameter_name, parameter in base_model_sharded.named_parameters(): requires_grad = parameter.requires_grad - assert requires_grad == True - output_text.append("{}\t{}".format(requires_grad, parameter_name)) + assert requires_grad + output_text.append(f"{requires_grad}\t{parameter_name}") with open("data/output_base.txt", "w") as output_file: output_file.write("\n".join(output_text)) -def test_get_module_types(lora_model_sharded): - """ - Output type of each module. - """ +def test_get_module_types(lora_model_sharded: torch.nn.Module) -> None: + """Output type of each module.""" output_text = [] print(lora_model_sharded) for module_name, module in lora_model_sharded.named_modules(): - output_text.append("{}\t{}".format(module_name, type(module))) + output_text.append(f"{module_name}\t{type(module)}") with open("data/module_types.txt", "w") as output_file: output_file.write("\n".join(output_text)) +@pytest.mark.usefixtures("_setup_and_teardown_torch_process_group") def test_fsdp_lora_model_require_grad( - lora_model_sharded, setup_and_teardown_torch_process_group -): - """ - Test partitioning lora peft model. - """ - + lora_model_sharded: torch.nn.Module, +) -> None: + """Test partitioning lora peft model.""" requires_grad_counters = defaultdict(Counter) output_text = [] @@ -169,12 +185,12 @@ def test_fsdp_lora_model_require_grad( requires_grad = parameter.requires_grad requires_grad_counters[requires_grad][parameter_name] += 1 if re.search("lora_[A|B]", parameter_name) is not None: - assert requires_grad == True, parameter_name + assert requires_grad, parameter_name else: - assert requires_grad == False, parameter_name + assert not requires_grad, parameter_name output_text.append( - "{}\t{}\t{}".format(requires_grad, parameter.device, parameter_name) + f"{requires_grad}\t{parameter.device}\t{parameter_name}", ) if reference_device is not None: @@ -182,16 +198,15 @@ def test_fsdp_lora_model_require_grad( reference_device = parameter.device - # # Uncomment line below to see all parameter names. - # print(requires_grad_counters) with open("data/output.txt", "w") as output_file: output_file.write("\n".join(output_text)) -def test_forward_base(base_model_sharded, batch): - """ - Test forward run of sharded base model. - """ +def test_forward_base( + base_model_sharded: torch.nn.Module, + batch: dict[str, torch.Tensor], +) -> None: + """Test forward run of sharded base model.""" base_model_sharded.train() output = base_model_sharded(**batch) loss = output.loss @@ -201,10 +216,11 @@ def test_forward_base(base_model_sharded, batch): print(loss.shape) -def test_forward_lora(lora_model_sharded, batch): - """ - Test forward run of sharded lora model. - """ +def test_forward_lora( + lora_model_sharded: torch.nn.Module, + batch: dict[str, torch.Tensor], +) -> None: + """Test forward run of sharded lora model.""" lora_model_sharded.train() output = lora_model_sharded(**batch) loss = output.loss @@ -213,10 +229,11 @@ def test_forward_lora(lora_model_sharded, batch): print(loss.shape) -def test_forward_backward_lora(lora_model_sharded, batch): - """ - Test forward and backward run of sharded lora model. - """ +def test_forward_backward_lora( + lora_model_sharded: torch.nn.Module, + batch: dict[str, torch.Tensor], +) -> None: + """Test forward and backward run of sharded lora model.""" lora_model_sharded.train() output = lora_model_sharded(**batch) loss = output.loss @@ -228,10 +245,12 @@ def test_forward_backward_lora(lora_model_sharded, batch): print(loss.shape) -def test_train_lora(lora_model_sharded, optimizer_lora_sharded, batch): - """ - Test N optimization steps on the LoRA sharded model. - """ +def test_train_lora( + lora_model_sharded: torch.nn.Module, + optimizer_lora_sharded: torch.nn.Module, + batch: dict[str, torch.Tensor], +) -> None: + """Test N optimization steps on the LoRA sharded model.""" optimizer = optimizer_lora_sharded model = lora_model_sharded loss_values = [] diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index bdd6c73..e0e483f 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -1,17 +1,17 @@ from __future__ import annotations -from contextlib import contextmanager import math import os -from typing import Any +from contextlib import _GeneratorContextManager, contextmanager +from typing import Any, Callable, Generator import torch import torch.distributed as dist -import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from transformers import PreTrainedTokenizer +import wandb from vectorlm.dataset import Dataset from vectorlm.utils.data_utils import Config from vectorlm.utils.save_utils import ( @@ -29,12 +29,15 @@ @contextmanager -def timer_placeholder(task_name: str): +def _timer_placeholder( + _: str, + __: dict[str, Any] | None = None, +) -> Generator[None, None, None]: try: yield # start code block finally: # run before exiting - return + pass class Trainer: @@ -66,7 +69,10 @@ def __init__( config: Config, enable_wandb_logging: bool, original_dataset_length: int, - timer_handle=timer_placeholder, + timer_handle: Callable[ + [str, dict[str, Any] | None], + _GeneratorContextManager[None], + ] = _timer_placeholder, ) -> None: """Initialize the Trainer class. @@ -283,7 +289,8 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: tr_step_loss = out.loss (tr_step_loss / self.gas).backward() torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.config.max_grad_norm + self.model.parameters(), + self.config.max_grad_norm, ) else: # non-fsdp @@ -291,7 +298,8 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: tr_step_loss = out.loss (tr_step_loss / self.gas).backward() torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.config.max_grad_norm + self.model.parameters(), + self.config.max_grad_norm, ) else: @@ -301,7 +309,8 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: tr_step_loss = out.loss (tr_step_loss / self.gas).backward() torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.config.max_grad_norm + self.model.parameters(), + self.config.max_grad_norm, ) self.optimizer.step() if isinstance(self.lr_scheduler, ReduceLROnPlateau): @@ -335,10 +344,6 @@ def eval_step(self, epoch: int) -> float: batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) num_tokens = len(batch["input_ids"].flatten()) batch["labels"] = batch["labels"].type(torch.LongTensor) - batch = { - k: v.to(torch.cuda.current_device()) - for k, v in batch.items() - } with self.timer_handle("eval_step", {"num_tokens": num_tokens}): out = self.model(**batch) diff --git a/vectorlm/utils/convert_to_hf.py b/vectorlm/utils/convert_to_hf.py index 4c62811..da7aa3b 100644 --- a/vectorlm/utils/convert_to_hf.py +++ b/vectorlm/utils/convert_to_hf.py @@ -44,6 +44,9 @@ def converter(config: Config) -> None: True, False, 2048, # doesn't matter so hard-coded. + 0, + False, + True, ) model.load_state_dict(state_dict) model.save_pretrained( diff --git a/vectorlm/utils/misc_utils.py b/vectorlm/utils/misc_utils.py index 081deb2..1c59b98 100644 --- a/vectorlm/utils/misc_utils.py +++ b/vectorlm/utils/misc_utils.py @@ -4,8 +4,8 @@ from typing import Any import torch.distributed as dist -import wandb +import wandb from vectorlm.utils.data_utils import Config diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index b4a09e9..8d912da 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -2,11 +2,11 @@ import functools import re -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable import torch import torch.distributed as dist -from peft import LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model +from peft import LoraConfig, PeftModel, TaskType, get_peft_model from peft.utils.other import fsdp_auto_wrap_policy from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -18,7 +18,6 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, ) - from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -27,57 +26,35 @@ ) -def _is_bfloat_available() -> bool: - """ - Return whether bfloat is supported for the - current CUDA device. - - Returns: - -------- - bool. True if bfloat is supported. - """ - cuda_capability = torch.cuda.get_device_capability() - cuda_capability_str = "{}.{}".format(*cuda_capability) - if cuda_capability[0] >= 8.0: - print("Hardware capability {}; bfloat is supported".format(cuda_capability_str)) - return True - - else: - print( - "Hardware capability {}; bfloat isn't supported".format(cuda_capability_str) - ) - return False - - def get_half_precision_model(model: nn.Module) -> nn.Module: - """ - Cast model to appropriate half-precision format - depending on GPU hardware support. + """Cast model to appropriate half-precision format. Args: ---- - model: nn.Module to cast. Returns: ------- - nn.Module + """ - model = model.bfloat16() - return model + return model.bfloat16() def get_lora_model_from_base_model( - base_model: nn.Module, peft_config_dict: Dict + base_model: nn.Module, peft_config_dict: dict[str, Any], ) -> PeftModel: - """ - Initialize lora peft configuration from a non-lora model. + """Initialize lora peft configuration from a non-lora model. Args: - ----- + ---- base_model: HuggingFace Transformer model to wrap. peft_config_dict: configuration from yaml config file. + + Returns: + ------- + PeftModel + """ task_type_str = peft_config_dict["task_type"] task_type = getattr(TaskType, task_type_str) @@ -93,61 +70,6 @@ def get_lora_model_from_base_model( return lora_model -def _assert_parameter_shapes(model: nn.Module): - for name, parameter in model.named_parameters(): - print(name, parameter.dtype) - assert parameter.dtype is torch.float16 - - -def load_peft_model_and_tokenizer( - path: str, - use_mp: bool, - use_fa: bool, - max_seq_len: int, - peft_adapter_path: str, - adapter_name: str = "default", - is_trainable: bool = False, - config: PeftConfig | None = None, -) -> tuple[PeftModel, PreTrainedTokenizer]: - """Load a trained PEFT adapter to the base model and return the PeftModel. - - E.g., a base llama-2-13b-chat-hf w/ adapter named nifty - ├── adapters_lora - ├── llama-2-13b-chat-hf+nifty - - Args: - ---- - path: The path where the model and tokenizer are stored. - use_mp: Whether to use mixed-precision. - use_fa: Whether to use Flash Attention 2. - max_seq_len: The maximum sequence length. - peft_adapter_path: path to the adapter model, e.g. - adapters_lora/llama-2-13b-chat-hf+nifty - adapter_name: e.g. nifty - is_trainable: train or inference mode - config: additional configs - - Returns: - ------- - The PEFT model and tokenizer. - - """ - model, tokenizer = load_model_and_tokenizer( - path, - use_mp, - use_fa, - max_seq_len, - ) - peft_model = PeftModel.from_pretrained( - model, - peft_adapter_path, - adapter_name, - is_trainable, - config, - ) - return peft_model, tokenizer - - def load_model_and_tokenizer( path: str, use_mp: bool, @@ -218,7 +140,7 @@ def load_model_and_tokenizer( def fsdp_config( use_mp: bool, - model: nn.Module, + model_to_wrap: nn.Module, strategy: str, local_rank: int, low_cpu_mem_usage: bool, @@ -263,7 +185,7 @@ def _module_init_fn(module: nn.Module) -> Callable: sharding_strategy = getattr(ShardingStrategy, strategy) - ret_dict["auto_wrap_policy"] = fsdp_auto_wrap_policy(model) + ret_dict["auto_wrap_policy"] = fsdp_auto_wrap_policy(model_to_wrap) ret_dict["sharding_strategy"] = sharding_strategy ret_dict["device_id"] = torch.cuda.current_device() if low_cpu_mem_usage: @@ -347,32 +269,36 @@ def hook_activation_checkpointing( def get_submodule_by_pattern( - module: nn.Module, pattern: str -) -> Optional[type[nn.Module]]: - """ - Return the first module.cls that matches pattern, - at least partially. + module: nn.Module, pattern: str, +) -> type[nn.Module] | None: + """Return the first module.cls that matches pattern at least partially. With reference to get_module_class_from_name from HuggingFace accelerate `FullyShardedDataParallelPlugin`. Args: - ----- + ---- module: Layer container pattern: regular expression string. Returns: - -------- - Matched layer (nn.Module) or None if not matched. + ------- + nn.Module: matched layer (nn.Module), + or + None: if not matched. + """ modules_children = list(module.children()) module_name = module.__class__.__name__ if re.search(pattern, module_name) is not None: return module.__class__ - elif len(modules_children) == 0: - return - else: - for child_module in modules_children: - module_class = get_submodule_by_pattern(child_module, pattern) - if module_class is not None: - return module_class + + if len(modules_children) == 0: + return None + + for child_module in modules_children: + module_class = get_submodule_by_pattern(child_module, pattern) + if module_class is not None: + return module_class + + return None diff --git a/vectorlm/utils/optimizer_utils.py b/vectorlm/utils/optimizer_utils.py index 2d99d04..a202479 100644 --- a/vectorlm/utils/optimizer_utils.py +++ b/vectorlm/utils/optimizer_utils.py @@ -7,8 +7,8 @@ from transformers import get_scheduler -class PlateaeuWithWarmup(ReduceLROnPlateau): - """Class to implement plataeu scheduling with warmup. +class PlateauWithWarmup(ReduceLROnPlateau): + """Class to implement plateau scheduling with warmup. Attributes ---------- @@ -47,7 +47,7 @@ def __init__( verbose: bool = False, num_warmup_steps: int = 0, ) -> None: - """Initialize the PlateaeuWithWarmup scheduler class. + """Initialize the PlateauWithWarmup scheduler class. Arguments: --------- @@ -168,8 +168,8 @@ def get_custom_scheduler( The scheduler. """ - if name == "plataeu-with-warmup": - scheduler = PlateaeuWithWarmup(*args, **kwargs) + if name == "plateau-with-warmup": + scheduler = PlateauWithWarmup(*args, **kwargs) # required, otherwise the very first step the optimizer takes is at the # maximum set LR (because we step the scheduler *after* we step the # optimizer. As a result, the optimizer is set to max LR on the first diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index 816e2ec..295e0a0 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -75,7 +75,7 @@ def load_metadata( def get_latest_checkpoint_dir(folder_path: str) -> str: - """Find the latest checkpoing directory using regex. + """Find the latest checkpoint directory using regex. Args: ---- @@ -275,4 +275,4 @@ def load_scheduler( print(f"Loading scheduler state from {input_scheduler_file}") state_dict = torch.load(input_scheduler_file) scheduler.load_state_dict(state_dict) - print(f"Scheduler state loaded from {input_scheduler_file}") \ No newline at end of file + print(f"Scheduler state loaded from {input_scheduler_file}") From 78c6faf00275a6125e1a378262a0c954b74ef90a Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 18:53:34 -0400 Subject: [PATCH 45/89] Implemented restoring LoRA train state from filesystem. During training the adapter weights are saved to and loaded from the filesystem. The base model weights are loaded separately. Revised reference to optim_state_dict_to_load in load_optimizer. --- examples/llama_example.py | 22 ++++++++++++++++--- vectorlm/trainer.py | 41 ++++++++++++++++++++++++++++++----- vectorlm/utils/model_utils.py | 22 +++++++++++++++---- vectorlm/utils/save_utils.py | 36 +++++++++++++++++++++++++++++- 4 files changed, 108 insertions(+), 13 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 7d684fc..ec1f348 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -23,7 +23,7 @@ shard_model, ) from vectorlm.utils.optimizer_utils import get_custom_scheduler -from vectorlm.utils.save_utils import save_consolidated_model +from vectorlm.utils.save_utils import checkpoint_exists, save_consolidated_model def parse_args() -> Namespace: @@ -76,12 +76,27 @@ def main(config: Config) -> None: ) lora_peft_config = getattr( - config.train_parameters, "lora_peft_config", None, + config.train_parameters, + "lora_peft_config", + None, ) if lora_peft_config is not None: - model = get_lora_model_from_base_model(model, lora_peft_config) + is_peft_adapter_restored = False + peft_adapter_path = None + + # Restore peft adapter from filesystem if available. + if checkpoint_exists(training_args.output_dir): + peft_adapter_path = training_args.output_dir + is_peft_adapter_restored = True + + model = get_lora_model_from_base_model( + model, + lora_peft_config, + peft_adapter_path, + ) decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") + assert decoder_layer_module is not None, f"No DecoderLayer found in {model}" model = shard_model( model.bfloat16(), decoder_layer_module, @@ -127,6 +142,7 @@ def main(config: Config) -> None: dataset, optimizer, lr_scheduler, + is_peft_adapter_restored, ) # Checkpoint check. Always call before training. diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index e0e483f..848cb5d 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -5,6 +5,7 @@ from contextlib import _GeneratorContextManager, contextmanager from typing import Any, Callable, Generator +import peft import torch import torch.distributed as dist from torch.optim import Optimizer @@ -24,6 +25,7 @@ save_metadata, save_model, save_optimizer, + save_peft_adapter, save_scheduler, ) @@ -64,6 +66,9 @@ class Trainer: """ + peft_method: str | None = None + is_peft_adapter_restored: bool = False + def __init__( self, config: Config, @@ -105,6 +110,9 @@ def __init__( self.timer_handle = timer_handle self._post_process(original_dataset_length) + if hasattr(self.config, "lora_peft_config"): + self.peft_method = peft.utils.peft_types.PeftType.LORA + def _post_process(self, ds_orig_length: int) -> None: """Calculate steps for weight updates and saving.""" sharded_ds_orig_len = math.ceil( @@ -128,6 +136,7 @@ def prepare_trainer( dataset: Dataset, optimizer: Optimizer, lr_scheduler: LRScheduler | ReduceLROnPlateau, + is_peft_adapter_restored: bool = False, ) -> None: """Set all essential training requirements. @@ -139,6 +148,8 @@ def prepare_trainer( optimizer: The training optimizer. lr_scheduler: The LR scheduler. + is_peft_adapter_restored: whether peft is enabled and + adapters were restored from filesystem. """ self.model = model @@ -147,6 +158,8 @@ def prepare_trainer( self.optimizer = optimizer self.lr_scheduler = lr_scheduler + self.is_peft_adapter_restored = is_peft_adapter_restored + def save_checkpoint(self, epoch: int) -> None: """Save all states. @@ -173,13 +186,20 @@ def save_checkpoint(self, epoch: int) -> None: if rank == 0: save_metadata(save_dir, meta_dict) - with self.timer_handle("trainer_save_model"): - save_model(self.model, save_dir, rank) + with self.timer_handle("trainer_save_model", {}): + # Save adapter only if running LoRA. + # Merging adapters into base weights would require gathering + # all weights, which would incur significant overhead. + print(f"type(self.model): {type(self.model)}") + if self.peft_method is peft.utils.peft_types.PeftType.LORA: + save_peft_adapter(self.model, self.config.output_dir) + else: + save_model(self.model, save_dir, rank) - with self.timer_handle("trainer_save_optimizer"): + with self.timer_handle("trainer_save_optimizer", {}): save_optimizer(self.optimizer, self.model, save_dir, rank) - with self.timer_handle("train_save_scheduler"): + with self.timer_handle("train_save_scheduler", {}): save_scheduler(self.lr_scheduler, save_dir, rank) dist.barrier() @@ -203,7 +223,18 @@ def load_checkpoint(self, checkpoint_dir: str) -> int: self.tr_step = step self.dataset.set_processed_ids(ids) self.dataset.setup_dataloaders() - load_model(self.model, checkpoint_dir, rank) + + if self.peft_method is peft.utils.peft_types.PeftType.LORA: + # The FSDP wrapper is applied to self.model after the LoRA wrapper. + # It is unclear whether peft supports updating the LoRA wrapper + # tensors of a FSDP-wrapped module. Hence, the peft adapter + # is restored when initializing the LoRA wrapper + # before applying the FSDP wrapper, and the is_peft_adapter_restored + # ensures that the adapter is indeed applied. + assert self.is_peft_adapter_restored + else: + load_model(self.model, checkpoint_dir, rank) + load_optimizer(self.optimizer, self.model, checkpoint_dir, rank) load_scheduler(self.lr_scheduler, checkpoint_dir, rank) dist.barrier() diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 8d912da..5ee0333 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -42,7 +42,9 @@ def get_half_precision_model(model: nn.Module) -> nn.Module: def get_lora_model_from_base_model( - base_model: nn.Module, peft_config_dict: dict[str, Any], + base_model: PreTrainedModel, + peft_config_dict: dict[str, Any], + peft_adapter_path: str | None = None, ) -> PeftModel: """Initialize lora peft configuration from a non-lora model. @@ -50,6 +52,8 @@ def get_lora_model_from_base_model( ---- base_model: HuggingFace Transformer model to wrap. peft_config_dict: configuration from yaml config file. + peft_adapter_path: optionally, initialize peft adapters + using tensors loaded from the filesystem. Returns: ------- @@ -62,9 +66,18 @@ def get_lora_model_from_base_model( # See github.com/pytorch/pytorch/pull/102212 base_model.load_state_dict(base_model.state_dict(), assign=True) - lora_model = get_peft_model(base_model, lora_config) - lora_model = get_half_precision_model(lora_model) + if peft_adapter_path is not None: + lora_model = PeftModel.from_pretrained( + base_model, + peft_adapter_path, + is_trainable=True, + ) + print(f"Restored peft_adapter from {peft_adapter_path}.") + else: + lora_model = get_peft_model(base_model, lora_config) + + lora_model = get_half_precision_model(lora_model) assert isinstance(lora_model, PeftModel) lora_model.print_trainable_parameters() return lora_model @@ -269,7 +282,8 @@ def hook_activation_checkpointing( def get_submodule_by_pattern( - module: nn.Module, pattern: str, + module: nn.Module, + pattern: str, ) -> type[nn.Module] | None: """Return the first module.cls that matches pattern at least partially. diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index 295e0a0..d064571 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -3,7 +3,9 @@ import os import re +import peft import torch +import torch.distributed as dist from torch import nn from torch.distributed.fsdp import ( FullStateDictConfig, # general model non-sharded, non-flattened params @@ -171,6 +173,38 @@ def save_consolidated_model( torch.save(state_dict, save_path) +def get_peft_adapter_tensor_dict( + model: peft.peft_model.PeftModel, +) -> dict[str, torch.Tensor] | None: + """Return LoRA PEFT Adapter tensor state dict on rank 0. + + Returns None for all other ranks. + """ + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + if dist.get_rank() == 0: + return peft.utils.save_and_load.get_peft_model_state_dict(model) + + return None + + +def save_peft_adapter( + model: peft.peft_model.PeftModel, + output_path: str, +) -> None: + """Save peft adapter to filesystem in a FSDP environment.""" + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + if dist.get_rank() == 0: + model.save_pretrained(output_path) + + def save_optimizer( optimizer: Optimizer, model: nn.Module, @@ -229,7 +263,7 @@ def load_optimizer( ): print(f"Loading optimizer state from {input_optimizer_file}") opt_state = torch.load(input_optimizer_file) - opt_state = FSDP.optim_state_dict_to_load(opt_state, model, optimizer) + opt_state = FSDP.optim_state_dict_to_load(model, optimizer, opt_state) optimizer.load_state_dict(opt_state) print(f"Optimizer state loaded from {input_optimizer_file}") From aea2ed8b6a2c32b5904783206df414858467394e Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 19:09:19 -0400 Subject: [PATCH 46/89] Included train step number in LoRA adapter output path. --- examples/llama_example.py | 14 ++++++++++++-- vectorlm/trainer.py | 2 +- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index ec1f348..bfa0325 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -23,7 +23,11 @@ shard_model, ) from vectorlm.utils.optimizer_utils import get_custom_scheduler -from vectorlm.utils.save_utils import checkpoint_exists, save_consolidated_model +from vectorlm.utils.save_utils import ( + checkpoint_exists, + get_latest_checkpoint_dir, + save_consolidated_model, +) def parse_args() -> Namespace: @@ -86,7 +90,13 @@ def main(config: Config) -> None: # Restore peft adapter from filesystem if available. if checkpoint_exists(training_args.output_dir): - peft_adapter_path = training_args.output_dir + peft_adapter_path = os.path.join( + training_args.output_dir, + "checkpoints", + get_latest_checkpoint_dir( + os.path.join(training_args.output_dir, "checkpoints"), + ), + ) is_peft_adapter_restored = True model = get_lora_model_from_base_model( diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 848cb5d..0ca40f3 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -192,7 +192,7 @@ def save_checkpoint(self, epoch: int) -> None: # all weights, which would incur significant overhead. print(f"type(self.model): {type(self.model)}") if self.peft_method is peft.utils.peft_types.PeftType.LORA: - save_peft_adapter(self.model, self.config.output_dir) + save_peft_adapter(self.model, save_dir) else: save_model(self.model, save_dir, rank) From dad6553920982221e8447a0af26e234eaffd88fa Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:15:35 -0400 Subject: [PATCH 47/89] Added reference throughput table to documentation. --- docs/reference_throughput.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 docs/reference_throughput.md diff --git a/docs/reference_throughput.md b/docs/reference_throughput.md new file mode 100644 index 0000000..0e0cc6b --- /dev/null +++ b/docs/reference_throughput.md @@ -0,0 +1,30 @@ +# Reference Throughput + +We've benchmarked VectorLM on the Vaughan cluster for number of model architectures across a variety of node configurations. +In each experiment, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. +In experiments labelled as LoRA, we set hidden dimension to 8. + +Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning. + +| | Llama-2-13b-hf | Llama-2-7b-hf | Mistral-7B-v0.1 | Mixtral-8x7B-Instruct-v0.1 | gemma-2b | opt-350m | +|:-------------------------------------|-----------------:|----------------:|------------------:|-----------------------------:|-----------:|-----------:| +| (full_rank) NVIDIA A100-SXM4-80GB x1 | 424.726 | 570.818 | 528.747 | nan | nan | 780.045 | +| (full_rank) NVIDIA A100-SXM4-80GB x2 | 660.355 | 919.19 | 794.566 | 275.459 | nan | 1227.67 | +| (full_rank) NVIDIA A100-SXM4-80GB x4 | 1309.4 | 1744.39 | 1577.09 | 817.162 | nan | 2181.46 | +| (full_rank) NVIDIA A40 x1 | nan | 47.6435 | 107.503 | nan | nan | 666.881 | +| (full_rank) NVIDIA A40 x2 | nan | 313.074 | 322.624 | nan | nan | 854.672 | +| (full_rank) NVIDIA A40 x4 | 345.96 | 570.977 | 553.658 | nan | nan | 1765.49 | +| (full_rank) Tesla T4 x1 | nan | nan | nan | nan | nan | 475.51 | +| (full_rank) Tesla T4 x2 | nan | nan | nan | nan | nan | 768.008 | +| (full_rank) Tesla T4 x4 | nan | nan | nan | nan | nan | 1383.6 | +| (full_rank) Tesla T4 x8 | nan | nan | nan | nan | nan | 2414.68 | +| (lora) NVIDIA A100-SXM4-80GB x1 | 560.167 | 646.801 | 525.802 | nan | 851.678 | 859.379 | +| (lora) NVIDIA A100-SXM4-80GB x2 | 871.993 | 1157.17 | 1105.68 | 239.431 | 1724.57 | 1463.82 | +| (lora) NVIDIA A100-SXM4-80GB x4 | 1783.53 | 2091.03 | 2150.06 | 1309.74 | 2719.24 | 2381.01 | +| (lora) NVIDIA A40 x1 | 272.931 | 435.386 | 336.507 | nan | 983.256 | 652.611 | +| (lora) NVIDIA A40 x2 | 105.442 | 457.183 | 356.263 | nan | 725.723 | 1136.17 | +| (lora) NVIDIA A40 x4 | 543.22 | 715.416 | 642.642 | nan | 1302.62 | 1647.57 | +| (lora) Tesla T4 x1 | nan | nan | nan | nan | 148.272 | 571.471 | +| (lora) Tesla T4 x2 | nan | 101.126 | 102.859 | nan | 256.534 | 811.159 | +| (lora) Tesla T4 x4 | nan | 188.575 | 190.127 | nan | 495.755 | 1506.05 | +| (lora) Tesla T4 x8 | 196.709 | 372.375 | 351.361 | nan | 897.81 | 2945.86 | \ No newline at end of file From bbcda75454553b258ae096bb4624e7a5250f6e8a Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:18:13 -0400 Subject: [PATCH 48/89] Added unit description to reference throughput table. Applied markdown formatting via prettier. --- docs/reference_throughput.md | 48 +++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/docs/reference_throughput.md b/docs/reference_throughput.md index 0e0cc6b..f6a0c35 100644 --- a/docs/reference_throughput.md +++ b/docs/reference_throughput.md @@ -1,30 +1,32 @@ # Reference Throughput -We've benchmarked VectorLM on the Vaughan cluster for number of model architectures across a variety of node configurations. +We've benchmarked VectorLM on the Vaughan cluster for number of model architectures across a variety of node configurations. In each experiment, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. In experiments labelled as LoRA, we set hidden dimension to 8. Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning. -| | Llama-2-13b-hf | Llama-2-7b-hf | Mistral-7B-v0.1 | Mixtral-8x7B-Instruct-v0.1 | gemma-2b | opt-350m | -|:-------------------------------------|-----------------:|----------------:|------------------:|-----------------------------:|-----------:|-----------:| -| (full_rank) NVIDIA A100-SXM4-80GB x1 | 424.726 | 570.818 | 528.747 | nan | nan | 780.045 | -| (full_rank) NVIDIA A100-SXM4-80GB x2 | 660.355 | 919.19 | 794.566 | 275.459 | nan | 1227.67 | -| (full_rank) NVIDIA A100-SXM4-80GB x4 | 1309.4 | 1744.39 | 1577.09 | 817.162 | nan | 2181.46 | -| (full_rank) NVIDIA A40 x1 | nan | 47.6435 | 107.503 | nan | nan | 666.881 | -| (full_rank) NVIDIA A40 x2 | nan | 313.074 | 322.624 | nan | nan | 854.672 | -| (full_rank) NVIDIA A40 x4 | 345.96 | 570.977 | 553.658 | nan | nan | 1765.49 | -| (full_rank) Tesla T4 x1 | nan | nan | nan | nan | nan | 475.51 | -| (full_rank) Tesla T4 x2 | nan | nan | nan | nan | nan | 768.008 | -| (full_rank) Tesla T4 x4 | nan | nan | nan | nan | nan | 1383.6 | -| (full_rank) Tesla T4 x8 | nan | nan | nan | nan | nan | 2414.68 | -| (lora) NVIDIA A100-SXM4-80GB x1 | 560.167 | 646.801 | 525.802 | nan | 851.678 | 859.379 | -| (lora) NVIDIA A100-SXM4-80GB x2 | 871.993 | 1157.17 | 1105.68 | 239.431 | 1724.57 | 1463.82 | -| (lora) NVIDIA A100-SXM4-80GB x4 | 1783.53 | 2091.03 | 2150.06 | 1309.74 | 2719.24 | 2381.01 | -| (lora) NVIDIA A40 x1 | 272.931 | 435.386 | 336.507 | nan | 983.256 | 652.611 | -| (lora) NVIDIA A40 x2 | 105.442 | 457.183 | 356.263 | nan | 725.723 | 1136.17 | -| (lora) NVIDIA A40 x4 | 543.22 | 715.416 | 642.642 | nan | 1302.62 | 1647.57 | -| (lora) Tesla T4 x1 | nan | nan | nan | nan | 148.272 | 571.471 | -| (lora) Tesla T4 x2 | nan | 101.126 | 102.859 | nan | 256.534 | 811.159 | -| (lora) Tesla T4 x4 | nan | 188.575 | 190.127 | nan | 495.755 | 1506.05 | -| (lora) Tesla T4 x8 | 196.709 | 372.375 | 351.361 | nan | 897.81 | 2945.86 | \ No newline at end of file +All values in the table below represent the overall training throughput in tokens per second, aggregated across all GPU devices. + +| | Llama-2-13b-hf | Llama-2-7b-hf | Mistral-7B-v0.1 | Mixtral-8x7B-Instruct-v0.1 | gemma-2b | opt-350m | +| :----------------------------------- | -------------: | ------------: | --------------: | -------------------------: | -------: | -------: | +| (full_rank) NVIDIA A100-SXM4-80GB x1 | 424.726 | 570.818 | 528.747 | nan | nan | 780.045 | +| (full_rank) NVIDIA A100-SXM4-80GB x2 | 660.355 | 919.19 | 794.566 | 275.459 | nan | 1227.67 | +| (full_rank) NVIDIA A100-SXM4-80GB x4 | 1309.4 | 1744.39 | 1577.09 | 817.162 | nan | 2181.46 | +| (full_rank) NVIDIA A40 x1 | nan | 47.6435 | 107.503 | nan | nan | 666.881 | +| (full_rank) NVIDIA A40 x2 | nan | 313.074 | 322.624 | nan | nan | 854.672 | +| (full_rank) NVIDIA A40 x4 | 345.96 | 570.977 | 553.658 | nan | nan | 1765.49 | +| (full_rank) Tesla T4 x1 | nan | nan | nan | nan | nan | 475.51 | +| (full_rank) Tesla T4 x2 | nan | nan | nan | nan | nan | 768.008 | +| (full_rank) Tesla T4 x4 | nan | nan | nan | nan | nan | 1383.6 | +| (full_rank) Tesla T4 x8 | nan | nan | nan | nan | nan | 2414.68 | +| (lora) NVIDIA A100-SXM4-80GB x1 | 560.167 | 646.801 | 525.802 | nan | 851.678 | 859.379 | +| (lora) NVIDIA A100-SXM4-80GB x2 | 871.993 | 1157.17 | 1105.68 | 239.431 | 1724.57 | 1463.82 | +| (lora) NVIDIA A100-SXM4-80GB x4 | 1783.53 | 2091.03 | 2150.06 | 1309.74 | 2719.24 | 2381.01 | +| (lora) NVIDIA A40 x1 | 272.931 | 435.386 | 336.507 | nan | 983.256 | 652.611 | +| (lora) NVIDIA A40 x2 | 105.442 | 457.183 | 356.263 | nan | 725.723 | 1136.17 | +| (lora) NVIDIA A40 x4 | 543.22 | 715.416 | 642.642 | nan | 1302.62 | 1647.57 | +| (lora) Tesla T4 x1 | nan | nan | nan | nan | 148.272 | 571.471 | +| (lora) Tesla T4 x2 | nan | 101.126 | 102.859 | nan | 256.534 | 811.159 | +| (lora) Tesla T4 x4 | nan | 188.575 | 190.127 | nan | 495.755 | 1506.05 | +| (lora) Tesla T4 x8 | 196.709 | 372.375 | 351.361 | nan | 897.81 | 2945.86 | From d397488b2f71e730d65189025ecd741ce9c4e252 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:19:24 -0400 Subject: [PATCH 49/89] Added unit description to reference throughput table. Applied markdown formatting via prettier. --- docs/reference_throughput.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference_throughput.md b/docs/reference_throughput.md index f6a0c35..f17e739 100644 --- a/docs/reference_throughput.md +++ b/docs/reference_throughput.md @@ -6,7 +6,7 @@ In experiments labelled as LoRA, we set hidden dimension to 8. Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning. -All values in the table below represent the overall training throughput in tokens per second, aggregated across all GPU devices. +All values in the table below represent the median training throughput in tokens per second across all training steps, aggregated across all GPU devices. | | Llama-2-13b-hf | Llama-2-7b-hf | Mistral-7B-v0.1 | Mixtral-8x7B-Instruct-v0.1 | gemma-2b | opt-350m | | :----------------------------------- | -------------: | ------------: | --------------: | -------------------------: | -------: | -------: | From 35b97b8d5b33aa902e5619b3b7279b5014250301 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:21:58 -0400 Subject: [PATCH 50/89] Benchmark: added option to override max_length of pre-trained model. --- profiling/benchmark.py | 6 +++++- profiling/launch_benchmark.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index e169e53..23548f3 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -67,6 +67,7 @@ def parse_args() -> Namespace: "--num_eval_examples", default=1000, ) + parser.add_argument("--max_length", type=int) return parser.parse_args() @@ -240,7 +241,7 @@ def __init__( self.num_train_examples = num_train_examples self.num_eval_examples = num_eval_examples - if max_length is not None: + if (max_length is not None) and (max_length < 0): self.max_length = max_length else: self.max_length = min(tokenizer.model_max_length, _MAX_SEQ_LENGTH) @@ -358,8 +359,11 @@ def load_datasets(self) -> None: num_train_examples=args.num_train_examples, num_eval_examples=args.num_eval_examples, tokenizer=tokenizer, + max_length=args.max_length, ) + write_metrics("max_length", dataset.max_length) + # instantiate trainer trainer = Trainer( config=training_args, diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index 9374350..acb12e0 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -36,6 +36,8 @@ "profiling/configs/benchmark.yaml", ] +max_length_list = [1024, 2048, 4096, -1] + slurm_flags_options = { "nodes": [1], "mem-per-gpu": ["16GB"], @@ -52,6 +54,7 @@ ["profiling/launch_benchmark.sh"], config_list, model_list, + max_length_list, ] timestamp = int(time.time()) @@ -80,7 +83,7 @@ arg = (f"--{key}", str(value)) args.extend(arg) - args.extend(pos_args_option) + args.extend(str(arg) for arg in pos_args_option) args_list.append(args) print(" ".join(args)) From 6af7791e45465c32388fc96a67586fdb08194746 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:22:58 -0400 Subject: [PATCH 51/89] Deleted unused `accelerate` dependency from requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 04f0da2..bdf0065 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -accelerate datasets transformers sentencepiece From 97be477eab7acbadab5cac3aabfa9a2312a04dd5 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:23:53 -0400 Subject: [PATCH 52/89] Benchmark: added comment on max_length. --- profiling/launch_benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index acb12e0..be2263d 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -36,6 +36,7 @@ "profiling/configs/benchmark.yaml", ] +# Set to (-1) to fall back to the max context length of the pre-trained model. max_length_list = [1024, 2048, 4096, -1] slurm_flags_options = { From b43e5650cb759d7cf10b7be83cf1b9c6702922c3 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 15 Apr 2024 22:31:19 -0400 Subject: [PATCH 53/89] Benchmark: added comment on batch size. --- docs/reference_throughput.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/reference_throughput.md b/docs/reference_throughput.md index f17e739..4593c33 100644 --- a/docs/reference_throughput.md +++ b/docs/reference_throughput.md @@ -1,9 +1,10 @@ # Reference Throughput We've benchmarked VectorLM on the Vaughan cluster for number of model architectures across a variety of node configurations. -In each experiment, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. In experiments labelled as LoRA, we set hidden dimension to 8. +For consistency, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. Especially for smaller models, it might be possible to achieve a higher throughput by increasing the batch size. + Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning. All values in the table below represent the median training throughput in tokens per second across all training steps, aggregated across all GPU devices. From 607de70eb1d3dbc3c2da8a1a3f693156daf457c8 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 08:07:44 -0400 Subject: [PATCH 54/89] Benchmark: added option to override batch size. --- profiling/benchmark.py | 7 ++++++- profiling/launch_benchmark.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index 23548f3..b5300af 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -68,6 +68,7 @@ def parse_args() -> Namespace: default=1000, ) parser.add_argument("--max_length", type=int) + parser.add_argument("--training_batch_size", type=int) return parser.parse_args() @@ -241,7 +242,7 @@ def __init__( self.num_train_examples = num_train_examples self.num_eval_examples = num_eval_examples - if (max_length is not None) and (max_length < 0): + if (max_length is not None) and (max_length > 0): self.max_length = max_length else: self.max_length = min(tokenizer.model_max_length, _MAX_SEQ_LENGTH) @@ -271,6 +272,10 @@ def load_datasets(self) -> None: config = Config(yaml_path=args.yaml_path) setup(config.train_parameters.output_dir) + if args.training_batch_size is not None: + config.dataset.train_bs = args.training_batch_size + write_metrics("training_batch_size", args.training_batch_size) + print(f"Writing metrics to {output_path}") write_metrics("model_name", args.model_name) write_metrics("config", {**config.__dict__}) diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index be2263d..e9509ed 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -38,6 +38,7 @@ # Set to (-1) to fall back to the max context length of the pre-trained model. max_length_list = [1024, 2048, 4096, -1] +batch_size = [8, 16, 32, 64, 128] slurm_flags_options = { "nodes": [1], @@ -56,6 +57,7 @@ config_list, model_list, max_length_list, + batch_size, ] timestamp = int(time.time()) From bdef48f31ba2d14b13772cca3052908549d2177a Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 08:22:32 -0400 Subject: [PATCH 55/89] Benchmark throughput documentation: revised word choices. --- docs/reference_throughput.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference_throughput.md b/docs/reference_throughput.md index 4593c33..c725bb5 100644 --- a/docs/reference_throughput.md +++ b/docs/reference_throughput.md @@ -1,9 +1,9 @@ # Reference Throughput -We've benchmarked VectorLM on the Vaughan cluster for number of model architectures across a variety of node configurations. +We've benchmarked VectorLM on the Vaughan cluster for a number of model architectures across a variety of node configurations. In experiments labelled as LoRA, we set hidden dimension to 8. -For consistency, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. Especially for smaller models, it might be possible to achieve a higher throughput by increasing the batch size. +For consistency, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. Note that especially for smaller models, it might be possible to further increase throughput by switching to a larger batch size. Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning. From 3294a399de9e99fe1f6683b48f4ec3b3feed6447 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 11:48:43 -0400 Subject: [PATCH 56/89] LoRA Hot-Swap: Implemented vLLM integration test scaffolding and PyTest fixtures. --- vectorlm/tests/test_vllm.py | 173 ++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 vectorlm/tests/test_vllm.py diff --git a/vectorlm/tests/test_vllm.py b/vectorlm/tests/test_vllm.py new file mode 100644 index 0000000..17e2a48 --- /dev/null +++ b/vectorlm/tests/test_vllm.py @@ -0,0 +1,173 @@ +"""vLLM Integration tests. + +LLM for scaffolding: +- gemma-2b (vLLM supports LoRA for Gemma but not OPT) + +Scaffolding and fixtures: +- vLLM + - Spin up vLLM Engine via Python API + - LoRA request via peft LoRA adapter from + - regular disk. + - folder saved in /dev/shm. +- References top-k log probabilities via vLLM + - Base model + - LoRA model loaded from /dev/shm + - LoRA model loaded from disk +""" + +from __future__ import annotations + +import numpy as np +import pytest +import vllm +import vllm.sequence +from vllm.lora.request import LoRARequest + +BASE_MODEL_PATH = "/model-weights/gemma-2b" +LORA_ADAPTER_PATH = "data/example-adapters/gemma-2b-gsm8k" +NUM_TOP_LOGPROBS = 5 + + +@pytest.fixture(scope="session") +def vllm_model() -> vllm.LLM: + """Spin up vLLM base model.""" + return vllm.LLM( + BASE_MODEL_PATH, + gpu_memory_utilization=0.3, + enable_lora=True, + ) + + +@pytest.fixture(scope="session") +def vllm_sampling_params() -> vllm.SamplingParams: + """Return example vLLM sampling parameters for consistency across tests.""" + return vllm.SamplingParams(logprobs=NUM_TOP_LOGPROBS, temperature=0, seed=1) + + +@pytest.fixture(scope="session") +def example_prompts() -> list[str]: + """Return example prompts.""" + return [ + "Vector Institute is located in", + "The answer to life the universe and everything is of course", + "Vector Institute is located in", + ] + + +def extract_logprobs( + vllm_responses: list[vllm.RequestOutput], +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Extract logprobs from vllm response. + + Additionally, ensures none of these output is None. + + Params + ------ + vllm_responses: output from LLM.generate() + + Returns + ------- + Nested list, one list of logprobs instance for each prompt. + + """ + logprobs_responses: list[list[vllm.sequence.SampleLogprobs]] = [] + for response in vllm_responses: + for output in response.outputs: + logprobs_options: list[vllm.sequence.SampleLogprobs] = [] + logprobs = output.logprobs + assert logprobs is not None + logprobs_options.append(logprobs) + + logprobs_responses.append(logprobs_options) + + return logprobs_responses + + +def assert_logprobs_allclose( + logprobs_a: vllm.sequence.SampleLogprobs, + logprobs_b: vllm.sequence.SampleLogprobs, +) -> None: + """Ensure that logprobs_a are all close with logprobs_b.""" + assert len(logprobs_a) == len(logprobs_b) + for token_logprobs_a, token_logprobs_b in zip(logprobs_a, logprobs_b): + assert token_logprobs_a.keys() == token_logprobs_b.keys() + token_logprobs_a_array = np.asarray( + [token_logprobs_a[k].logprob for k in token_logprobs_a], + ) + token_logprobs_b_array = np.asarray( + [token_logprobs_b[k].logprob for k in token_logprobs_a], + ) + assert np.allclose( + np.asarray(token_logprobs_a_array), + np.asarray(token_logprobs_b_array), + ) + + +@pytest.fixture(scope="session") +def base_llm_logprobs( + vllm_model: vllm.LLM, + example_prompts: list[str], + vllm_sampling_params: vllm.SamplingParams, +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs for base LLM (no LoRA adapter).""" + vllm_responses = vllm_model.generate(example_prompts, vllm_sampling_params) + return extract_logprobs(vllm_responses) + + +@pytest.fixture(scope="session") +def lora_request() -> LoRARequest: + """Return LoRARequest for vLLM LoRA requests.""" + return LoRARequest("example_adapter", 1, LORA_ADAPTER_PATH) + + +@pytest.fixture(scope="session") +def lora_llm_logprobs( + vllm_model: vllm.LLM, + example_prompts: list[str], + vllm_sampling_params: vllm.SamplingParams, + lora_request: LoRARequest, +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs for LoRA-adapted LLM.""" + vllm_responses = vllm_model.generate( + example_prompts, + vllm_sampling_params, + lora_request=lora_request, + ) + return extract_logprobs(vllm_responses) + + +@pytest.mark.parametrize( + "logprobs_fixture_name", + ["base_llm_logprobs", "lora_llm_logprobs"], +) +def test_get_logprobs( + logprobs_fixture_name: str, + request: pytest.FixtureRequest, +) -> None: + """Test obtaining logprobs from base vLLM model.""" + output_logprobs: list[list[vllm.sequence.SampleLogprobs]] = ( + request.getfixturevalue(logprobs_fixture_name) + ) + assert_logprobs_allclose(output_logprobs[0][0], output_logprobs[2][0]) + + with pytest.raises(AssertionError): + assert_logprobs_allclose( + output_logprobs[2][0], + output_logprobs[1][0], + ) + + +def test_compare_ref_logprobs( + base_llm_logprobs: list[list[vllm.sequence.SampleLogprobs]], + lora_llm_logprobs: list[list[vllm.sequence.SampleLogprobs]], +) -> None: + """Ensure base_llm_logprobs are different from lora_llm_logprobs.""" + for base_llm_seq_logprobs, lora_llm_seq_logprobs in zip( + base_llm_logprobs, + lora_llm_logprobs, + ): + with pytest.raises(AssertionError): + assert_logprobs_allclose( + base_llm_seq_logprobs[0], + lora_llm_seq_logprobs[0], + ) From 2bb7bad55cc62e6f76f07f89fd1efc691d973495 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 18:21:32 -0400 Subject: [PATCH 57/89] LoRA Hot-Swap: Implemented vLLM LoRA hot-swap integration proof-of-concept. --- vectorlm/tests/test_vllm.py | 189 +++++++++++++++++++++++++++++------- 1 file changed, 156 insertions(+), 33 deletions(-) diff --git a/vectorlm/tests/test_vllm.py b/vectorlm/tests/test_vllm.py index 17e2a48..7eb5bb9 100644 --- a/vectorlm/tests/test_vllm.py +++ b/vectorlm/tests/test_vllm.py @@ -17,17 +17,69 @@ from __future__ import annotations +import os.path +import shutil +from typing import Generator + import numpy as np import pytest import vllm import vllm.sequence +from huggingface_hub import snapshot_download from vllm.lora.request import LoRARequest BASE_MODEL_PATH = "/model-weights/gemma-2b" -LORA_ADAPTER_PATH = "data/example-adapters/gemma-2b-gsm8k" +LORA_ADAPTER_HF_HUB_REPO = "jacobthebanana/example-gemma-2b-lora-gsm8k" +LORA_ADAPTER_LOCAL_FOLDER = "data/example_lora_adapter" NUM_TOP_LOGPROBS = 5 +@pytest.fixture(scope="session") +def lora_adapter_path() -> str: + """Download example LoRA adapters from HuggingFace hub. + + Returns + ------- + Path to the adapters on local filesystem. + + """ + if not os.path.exists(f"{LORA_ADAPTER_HF_HUB_REPO}"): + snapshot_download( + LORA_ADAPTER_HF_HUB_REPO, + local_dir=LORA_ADAPTER_LOCAL_FOLDER, + ) + + return LORA_ADAPTER_LOCAL_FOLDER + + +@pytest.fixture(scope="session") +def lora_adapter_path_dev_shm( + lora_adapter_path: str, +) -> Generator[str, None, None]: + """Create a copy of LoRA adapters on /dev/shm. + + Returns + ------- + Path to adapters on the /dev/shm filesystem. + + """ + # Specifically require /dev/shm since /tmp might go to an actual disk, + # incurring overhead and unnecessary SSD wear. + lora_adapter_dev_shm_path = f"/dev/shm/{LORA_ADAPTER_HF_HUB_REPO}" # noqa: S108 + os.makedirs(lora_adapter_dev_shm_path, exist_ok=True) + shutil.copytree( + lora_adapter_path, + lora_adapter_dev_shm_path, + dirs_exist_ok=True, + ) + print(f"Copy: {lora_adapter_path}, {lora_adapter_dev_shm_path}") + + yield lora_adapter_dev_shm_path + + # Clean up to free memory. + shutil.rmtree(lora_adapter_dev_shm_path) + + @pytest.fixture(scope="session") def vllm_model() -> vllm.LLM: """Spin up vLLM base model.""" @@ -41,7 +93,11 @@ def vllm_model() -> vllm.LLM: @pytest.fixture(scope="session") def vllm_sampling_params() -> vllm.SamplingParams: """Return example vLLM sampling parameters for consistency across tests.""" - return vllm.SamplingParams(logprobs=NUM_TOP_LOGPROBS, temperature=0, seed=1) + return vllm.SamplingParams( + logprobs=NUM_TOP_LOGPROBS, + temperature=0.5, + seed=1, + ) @pytest.fixture(scope="session") @@ -49,7 +105,7 @@ def example_prompts() -> list[str]: """Return example prompts.""" return [ "Vector Institute is located in", - "The answer to life the universe and everything is of course", + "The answer to life the universe and everything is ", "Vector Institute is located in", ] @@ -114,20 +170,16 @@ def base_llm_logprobs( return extract_logprobs(vllm_responses) -@pytest.fixture(scope="session") -def lora_request() -> LoRARequest: - """Return LoRARequest for vLLM LoRA requests.""" - return LoRARequest("example_adapter", 1, LORA_ADAPTER_PATH) - - -@pytest.fixture(scope="session") -def lora_llm_logprobs( +def get_lora_llm_logprobs( vllm_model: vllm.LLM, example_prompts: list[str], vllm_sampling_params: vllm.SamplingParams, - lora_request: LoRARequest, + _lora_adapter_path_fixture_name: str, + request: pytest.FixtureRequest, ) -> list[list[vllm.sequence.SampleLogprobs]]: """Return logprobs for LoRA-adapted LLM.""" + lora_adapter_path = request.getfixturevalue(_lora_adapter_path_fixture_name) + lora_request = LoRARequest("example_adapter", 1, lora_adapter_path) vllm_responses = vllm_model.generate( example_prompts, vllm_sampling_params, @@ -136,38 +188,109 @@ def lora_llm_logprobs( return extract_logprobs(vllm_responses) +@pytest.fixture(scope="session") +def lora_llm_logprobs_local_and_dev_shm( + vllm_model: vllm.LLM, + example_prompts: list[str], + vllm_sampling_params: vllm.SamplingParams, + request: pytest.FixtureRequest, +) -> tuple[list[list[vllm.sequence.SampleLogprobs]], ...]: + """Return logprobs via LoRA adapter loaded locally and from /dev/shm. + + Returns + ------- + Two list of lists (options) of vLLM logprobs. + local_adapter_logprobs, dev_shm_adapter_logprobs + + """ + return tuple( + get_lora_llm_logprobs( + vllm_model, + example_prompts, + vllm_sampling_params, + _adapter_path, + request, + ) + for _adapter_path in [ + "lora_adapter_path", + "lora_adapter_path_dev_shm", + ] + ) + + +# Reuse this test case definition for both base and LoRA logprobs. @pytest.mark.parametrize( "logprobs_fixture_name", - ["base_llm_logprobs", "lora_llm_logprobs"], + ["base_llm_logprobs", "lora_llm_logprobs_local_and_dev_shm"], ) -def test_get_logprobs( +def test_logprobs_consistency( logprobs_fixture_name: str, request: pytest.FixtureRequest, ) -> None: - """Test obtaining logprobs from base vLLM model.""" - output_logprobs: list[list[vllm.sequence.SampleLogprobs]] = ( - request.getfixturevalue(logprobs_fixture_name) - ) - assert_logprobs_allclose(output_logprobs[0][0], output_logprobs[2][0]) + """Verify consistency of logprobs from base vLLM model. - with pytest.raises(AssertionError): - assert_logprobs_allclose( - output_logprobs[2][0], - output_logprobs[1][0], - ) + Since vLLM seed is fixed, the same prompt should produce + the same logprobs. + """ + _logprobs_fixture_value: ( + list[list[vllm.sequence.SampleLogprobs]] + | tuple[list[list[vllm.sequence.SampleLogprobs]], ...] + ) = request.getfixturevalue(logprobs_fixture_name) + + if isinstance(_logprobs_fixture_value, tuple): + # A number of logprobs were returned + # (e.g., one for local LoRA, one for /dev/shm) + output_logprobs = _logprobs_fixture_value + else: + output_logprobs = [_logprobs_fixture_value] + + for logprobs in output_logprobs: + assert_logprobs_allclose(logprobs[0][0], logprobs[2][0]) + + with pytest.raises(AssertionError): + assert_logprobs_allclose(logprobs[2][0], logprobs[1][0]) def test_compare_ref_logprobs( base_llm_logprobs: list[list[vllm.sequence.SampleLogprobs]], - lora_llm_logprobs: list[list[vllm.sequence.SampleLogprobs]], + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], ) -> None: """Ensure base_llm_logprobs are different from lora_llm_logprobs.""" - for base_llm_seq_logprobs, lora_llm_seq_logprobs in zip( - base_llm_logprobs, - lora_llm_logprobs, + # Test both lora_adapter options: disk and ram-disk + for lora_llm_logprobs in lora_llm_logprobs_local_and_dev_shm: + for base_llm_seq_logprobs, lora_llm_seq_logprobs in zip( + base_llm_logprobs, + lora_llm_logprobs, + ): + with pytest.raises(AssertionError): + assert_logprobs_allclose( + base_llm_seq_logprobs[0], + lora_llm_seq_logprobs[0], + ) + + +def test_compare_lora_logprobs( + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> None: + """Ensure LoRA logprobs from local and ram-disk adapters match.""" + for logprobs_local_seq, logprobs_dev_shm_seq in zip( + *lora_llm_logprobs_local_and_dev_shm, ): - with pytest.raises(AssertionError): - assert_logprobs_allclose( - base_llm_seq_logprobs[0], - lora_llm_seq_logprobs[0], - ) + # Each of these represents the logprobs options for a sequence output. + logprobs_local_seq: list[vllm.sequence.SampleLogprobs] + logprobs_dev_shm_seq: list[vllm.sequence.SampleLogprobs] + + assert_logprobs_allclose(logprobs_local_seq[0], logprobs_dev_shm_seq[0]) + sequence_tokens = "".join( + [ + str(next(iter(token.values())).decoded_token) + for token in logprobs_local_seq[0] + ], + ) + print(f"\nVerified equivalence: {sequence_tokens}") From 5d93afe1903e1e1adea5fdeabee1c5357ec3dc4c Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 18:29:05 -0400 Subject: [PATCH 58/89] LoRA Hot-Swap: added additional fixtures to enhance readability. --- vectorlm/tests/test_vllm.py | 55 ++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/vectorlm/tests/test_vllm.py b/vectorlm/tests/test_vllm.py index 7eb5bb9..debecce 100644 --- a/vectorlm/tests/test_vllm.py +++ b/vectorlm/tests/test_vllm.py @@ -218,10 +218,36 @@ def lora_llm_logprobs_local_and_dev_shm( ) -# Reuse this test case definition for both base and LoRA logprobs. +@pytest.fixture(scope="session") +def lora_llm_logprobs_local( + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs from locally-loaded LoRA adapters.""" + return lora_llm_logprobs_local_and_dev_shm[0] + + +@pytest.fixture(scope="session") +def lora_llm_logprobs_dev_shm( + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs from LoRA adapters loaded via /dev/shm ram-disk.""" + return lora_llm_logprobs_local_and_dev_shm[1] + + +# Reuse this test case definition across base and LoRA logprobs. @pytest.mark.parametrize( "logprobs_fixture_name", - ["base_llm_logprobs", "lora_llm_logprobs_local_and_dev_shm"], + [ + "base_llm_logprobs", + "lora_llm_logprobs_local", + "lora_llm_logprobs_dev_shm", + ], ) def test_logprobs_consistency( logprobs_fixture_name: str, @@ -232,23 +258,14 @@ def test_logprobs_consistency( Since vLLM seed is fixed, the same prompt should produce the same logprobs. """ - _logprobs_fixture_value: ( - list[list[vllm.sequence.SampleLogprobs]] - | tuple[list[list[vllm.sequence.SampleLogprobs]], ...] - ) = request.getfixturevalue(logprobs_fixture_name) - - if isinstance(_logprobs_fixture_value, tuple): - # A number of logprobs were returned - # (e.g., one for local LoRA, one for /dev/shm) - output_logprobs = _logprobs_fixture_value - else: - output_logprobs = [_logprobs_fixture_value] - - for logprobs in output_logprobs: - assert_logprobs_allclose(logprobs[0][0], logprobs[2][0]) - - with pytest.raises(AssertionError): - assert_logprobs_allclose(logprobs[2][0], logprobs[1][0]) + logprobs: list[list[vllm.sequence.SampleLogprobs]] = ( + request.getfixturevalue(logprobs_fixture_name) + ) + + assert_logprobs_allclose(logprobs[0][0], logprobs[2][0]) + + with pytest.raises(AssertionError): + assert_logprobs_allclose(logprobs[2][0], logprobs[1][0]) def test_compare_ref_logprobs( From 02988a57b507905f04744763f7662bc92164807d Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 22:36:15 -0400 Subject: [PATCH 59/89] LoRA Hot-Swap: Deleted redundant np.asarray call in integration test utils. Rephrased comments related to ramdisk. --- vectorlm/tests/test_vllm.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vectorlm/tests/test_vllm.py b/vectorlm/tests/test_vllm.py index debecce..ac3c66a 100644 --- a/vectorlm/tests/test_vllm.py +++ b/vectorlm/tests/test_vllm.py @@ -56,11 +56,11 @@ def lora_adapter_path() -> str: def lora_adapter_path_dev_shm( lora_adapter_path: str, ) -> Generator[str, None, None]: - """Create a copy of LoRA adapters on /dev/shm. + """Create a copy of LoRA adapters within /dev/shm. Returns ------- - Path to adapters on the /dev/shm filesystem. + Path to adapters in the /dev/shm filesystem. """ # Specifically require /dev/shm since /tmp might go to an actual disk, @@ -153,10 +153,7 @@ def assert_logprobs_allclose( token_logprobs_b_array = np.asarray( [token_logprobs_b[k].logprob for k in token_logprobs_a], ) - assert np.allclose( - np.asarray(token_logprobs_a_array), - np.asarray(token_logprobs_b_array), - ) + assert np.allclose(token_logprobs_a_array, token_logprobs_b_array) @pytest.fixture(scope="session") From 5ad5d90ea1cc8a93e0cafe8a71da84379366366e Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 22:39:32 -0400 Subject: [PATCH 60/89] LoRA Hot-Swap: Updated test case documentations to reflect code reuse in integration test. --- vectorlm/tests/test_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vectorlm/tests/test_vllm.py b/vectorlm/tests/test_vllm.py index ac3c66a..8f61ed1 100644 --- a/vectorlm/tests/test_vllm.py +++ b/vectorlm/tests/test_vllm.py @@ -250,7 +250,7 @@ def test_logprobs_consistency( logprobs_fixture_name: str, request: pytest.FixtureRequest, ) -> None: - """Verify consistency of logprobs from base vLLM model. + """Verify consistency of logprobs. Since vLLM seed is fixed, the same prompt should produce the same logprobs. From afb321ca1db89fdc3c5abb5bdfce82184dbf13c8 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Tue, 16 Apr 2024 23:14:45 -0400 Subject: [PATCH 61/89] Moved profiling-tracking logic out of Trainer. --- profiling/benchmark.py | 6 ++++- vectorlm/trainer.py | 55 ++++++++++-------------------------------- 2 files changed, 18 insertions(+), 43 deletions(-) diff --git a/profiling/benchmark.py b/profiling/benchmark.py index b5300af..c233f27 100644 --- a/profiling/benchmark.py +++ b/profiling/benchmark.py @@ -420,7 +420,11 @@ def load_datasets(self) -> None: file=sys.__stdout__, ): batch = next(train_dl_iterator) - trainer.step(batch, epoch) + num_tokens = len(batch["input_ids"].flatten()) + + with track_time("train_step", {"num_tokens": num_tokens}): + trainer.step(batch, epoch) + profile_handle.step() write_metrics( "torch.cuda.utilization", diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 0ca40f3..ff44cf3 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -2,8 +2,7 @@ import math import os -from contextlib import _GeneratorContextManager, contextmanager -from typing import Any, Callable, Generator +from typing import Any import peft import torch @@ -30,18 +29,6 @@ ) -@contextmanager -def _timer_placeholder( - _: str, - __: dict[str, Any] | None = None, -) -> Generator[None, None, None]: - try: - yield # start code block - finally: - # run before exiting - pass - - class Trainer: """Main trainer class. @@ -74,10 +61,6 @@ def __init__( config: Config, enable_wandb_logging: bool, original_dataset_length: int, - timer_handle: Callable[ - [str, dict[str, Any] | None], - _GeneratorContextManager[None], - ] = _timer_placeholder, ) -> None: """Initialize the Trainer class. @@ -87,7 +70,6 @@ def __init__( enable_wandb_logging: Whether to enable wandb logging. original_dataset_length: The length of the original dataset (divided by the batch size). - timer_handle: Optional context manager for profiling. """ self.config = config @@ -107,7 +89,6 @@ def __init__( self.num_update_steps_per_epoch = None self.max_steps = None self.saving_steps = None - self.timer_handle = timer_handle self._post_process(original_dataset_length) if hasattr(self.config, "lora_peft_config"): @@ -186,21 +167,16 @@ def save_checkpoint(self, epoch: int) -> None: if rank == 0: save_metadata(save_dir, meta_dict) - with self.timer_handle("trainer_save_model", {}): - # Save adapter only if running LoRA. - # Merging adapters into base weights would require gathering - # all weights, which would incur significant overhead. - print(f"type(self.model): {type(self.model)}") - if self.peft_method is peft.utils.peft_types.PeftType.LORA: - save_peft_adapter(self.model, save_dir) - else: - save_model(self.model, save_dir, rank) - - with self.timer_handle("trainer_save_optimizer", {}): - save_optimizer(self.optimizer, self.model, save_dir, rank) + # Save adapter only if running LoRA. + # Merging adapters into base weights would require gathering + # all weights, which would incur significant overhead. + if self.peft_method is peft.utils.peft_types.PeftType.LORA: + save_peft_adapter(self.model, save_dir) + else: + save_model(self.model, save_dir, rank) - with self.timer_handle("train_save_scheduler", {}): - save_scheduler(self.lr_scheduler, save_dir, rank) + save_optimizer(self.optimizer, self.model, save_dir, rank) + save_scheduler(self.lr_scheduler, save_dir, rank) dist.barrier() @@ -287,9 +263,7 @@ def step( ): self.save_checkpoint(epoch) - num_tokens = len(train_batch["input_ids"].flatten()) - with self.timer_handle("train_step", {"num_tokens": num_tokens}): - train_loss = self.train_step(train_batch, epoch) + train_loss = self.train_step(train_batch, epoch) test_loss = None if self.tr_step % self.logging_steps == 0: @@ -373,12 +347,9 @@ def eval_step(self, epoch: int) -> float: with torch.no_grad(): batch.pop("id") batch["input_ids"] = batch["input_ids"].type(torch.LongTensor) - num_tokens = len(batch["input_ids"].flatten()) batch["labels"] = batch["labels"].type(torch.LongTensor) - - with self.timer_handle("eval_step", {"num_tokens": num_tokens}): - out = self.model(**batch) - eval_loss += out.loss + out = self.model(**batch) + eval_loss += out.loss gathered_eval_loss = _gather(eval_loss.reshape(1)).mean().item() mean_eval_loss = gathered_eval_loss / len(self.dataset.eval_dataloader) From 5babf6ba99f768358e8bf962cf59141f5d2a843b Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 17 Apr 2024 09:53:43 -0400 Subject: [PATCH 62/89] Eliminated hasattr check related to no_sync since FSDP is always enabled. --- vectorlm/trainer.py | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index ff44cf3..30165eb 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -287,36 +287,19 @@ def train_step(self, batch: dict[str, torch.Tensor], epoch: int) -> float: self.dataset.update_processed_ids(ids) if (self.tr_step + 1) % self.gas != self.gas - 1: - if hasattr(self.model, "no_sync"): - # fsdp: no need to sync while accumulating gradients - with self.model.no_sync(): - out = self.model(**batch) - tr_step_loss = out.loss - (tr_step_loss / self.gas).backward() - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.config.max_grad_norm, - ) - else: - # non-fsdp + # no need to sync while accumulating gradients + with self.model.no_sync(): out = self.model(**batch) tr_step_loss = out.loss (tr_step_loss / self.gas).backward() - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.config.max_grad_norm, - ) - + self.model.clip_grad_norm_(self.config.max_grad_norm) else: # next forward / backward pass will be synced dist.barrier() out = self.model(**batch) tr_step_loss = out.loss (tr_step_loss / self.gas).backward() - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.config.max_grad_norm, - ) + self.model.clip_grad_norm_(self.config.max_grad_norm) self.optimizer.step() if isinstance(self.lr_scheduler, ReduceLROnPlateau): self.lr_scheduler.step(self.metric) From c1b31c4b9d200105c376ece36b7ad89002ed035c Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 17 Apr 2024 11:34:52 -0400 Subject: [PATCH 63/89] Replaced peft fsdp_auto_wrap_policy to eliminate implicit `accelerate` dependency. Eliminated redundant bfloat16 type conversion. Fixed scope of placeholder for `is_peft_adapter_restored`. --- examples/llama_example.py | 5 ++--- vectorlm/utils/model_utils.py | 37 +++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index bfa0325..fb7de59 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -84,10 +84,9 @@ def main(config: Config) -> None: "lora_peft_config", None, ) + is_peft_adapter_restored = False if lora_peft_config is not None: - is_peft_adapter_restored = False peft_adapter_path = None - # Restore peft adapter from filesystem if available. if checkpoint_exists(training_args.output_dir): peft_adapter_path = os.path.join( @@ -108,7 +107,7 @@ def main(config: Config) -> None: decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") assert decoder_layer_module is not None, f"No DecoderLayer found in {model}" model = shard_model( - model.bfloat16(), + model, decoder_layer_module, training_args.use_mp, training_args.use_activation_checkpointing, diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 5ee0333..2a5fcba 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist from peft import LoraConfig, PeftModel, TaskType, get_peft_model -from peft.utils.other import fsdp_auto_wrap_policy from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, @@ -18,6 +17,11 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, ) +from torch.distributed.fsdp.wrap import ( + _or_policy, + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -153,7 +157,7 @@ def load_model_and_tokenizer( def fsdp_config( use_mp: bool, - model_to_wrap: nn.Module, + layer_to_wrap: nn.Module, strategy: str, local_rank: int, low_cpu_mem_usage: bool, @@ -163,7 +167,7 @@ def fsdp_config( Args: ---- use_mp: Whether to use mixed-precision. - model_to_wrap: The HuggingFace model to wrap using FSDP. + layer_to_wrap: The layer we are wrapping using FSDP. strategy: The sharding strategy to use. local_rank: The local rank of the current worker. low_cpu_mem_usage: Whether to only load model weights on main rank, and @@ -196,9 +200,31 @@ def _module_init_fn(module: nn.Module) -> Callable: ) ret_dict["mixed_precision"] = mp_policy + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={layer_to_wrap}, + ) + + def _requires_grad_policy_fn(module: nn.Module) -> bool: + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_requires_grad_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=_requires_grad_policy_fn + ) + + auto_wrap_policy = functools.partial( + _or_policy, + policies=[lambda_requires_grad_policy, transformer_wrap_policy], + ) sharding_strategy = getattr(ShardingStrategy, strategy) - ret_dict["auto_wrap_policy"] = fsdp_auto_wrap_policy(model_to_wrap) + ret_dict["auto_wrap_policy"] = auto_wrap_policy ret_dict["sharding_strategy"] = sharding_strategy ret_dict["device_id"] = torch.cuda.current_device() if low_cpu_mem_usage: @@ -233,11 +259,10 @@ def shard_model( ------- The sharded module with the requested configurations. - """ fsdp_cfg = fsdp_config( use_mp, - model, + layer_to_wrap, strategy, local_rank, low_cpu_mem_usage, From f0b201c9d20a4e3762d2a98d2c219e2d86ff8d83 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 17 Apr 2024 12:00:37 -0400 Subject: [PATCH 64/89] Configured LoRA auto-wrap policy as off by default- enable the policy only when LoRA is required. --- examples/llama_example.py | 1 + vectorlm/utils/model_utils.py | 53 ++++++++++++++++++++++++----------- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index fb7de59..41a1dfd 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -114,6 +114,7 @@ def main(config: Config) -> None: training_args.sharding_strategy, local_rank, training_args.low_cpu_mem_usage, + enable_lora=(lora_peft_config is not None), ) # load dataset diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 2a5fcba..858f6e7 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -155,12 +155,31 @@ def load_model_and_tokenizer( return model, tokenizer +def lora_requires_grad_policy_fn(module: nn.Module) -> bool: + """Policy that "turns off" FSDP Flat Param for LoRA-enabled layers. + + FSDP requires consistent requires_grad for each flat param. + + Since LoRA requires_grad tensors are embedded within each layer. + This policy "turns off" FSDP flat param optimization by + requiring a separate flat param block for each tensor. + """ + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + def fsdp_config( use_mp: bool, layer_to_wrap: nn.Module, strategy: str, local_rank: int, low_cpu_mem_usage: bool, + enable_lora: bool = False, ) -> dict[str, Any]: """Get FSDP config. @@ -172,6 +191,7 @@ def fsdp_config( local_rank: The local rank of the current worker. low_cpu_mem_usage: Whether to only load model weights on main rank, and then scatter them to the other workers. + enable_lora: Whether to enable LoRA support. Returns: ------- @@ -205,23 +225,19 @@ def _module_init_fn(module: nn.Module) -> Callable: transformer_layer_cls={layer_to_wrap}, ) - def _requires_grad_policy_fn(module: nn.Module) -> bool: - if ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ): - return True - return False - - lambda_requires_grad_policy = functools.partial( - lambda_auto_wrap_policy, lambda_fn=_requires_grad_policy_fn - ) + if enable_lora: + # turns off FSDP Flat Param in LoRA layers. + lambda_requires_grad_policy = functools.partial( + lambda_auto_wrap_policy, + lambda_fn=lora_requires_grad_policy_fn, + ) + auto_wrap_policy = functools.partial( + _or_policy, + policies=[lambda_requires_grad_policy, transformer_wrap_policy], + ) + else: + auto_wrap_policy = transformer_wrap_policy - auto_wrap_policy = functools.partial( - _or_policy, - policies=[lambda_requires_grad_policy, transformer_wrap_policy], - ) sharding_strategy = getattr(ShardingStrategy, strategy) ret_dict["auto_wrap_policy"] = auto_wrap_policy @@ -241,6 +257,7 @@ def shard_model( strategy: str, local_rank: int, low_cpu_mem_usage: bool, + enable_lora: bool = False, ) -> nn.Module: """Shard the model to workers using FSDP. @@ -254,6 +271,9 @@ def shard_model( local_rank: The local rank of the current worker. low_cpu_mem_usage: Whether to only load model weights on main rank, and then scatter them to the other workers. + enable_lora: Whether to enable support for LoRA, where only a subset of + parameter tensors requires_grad. Enabling might significantly reduce + training throughput, so enable this only when actually using LoRA. Returns: ------- @@ -266,6 +286,7 @@ def shard_model( strategy, local_rank, low_cpu_mem_usage, + enable_lora, ) if dist.get_rank() == 0: print(f"FSDP config: {fsdp_cfg}") From 429ec5e48eade22c0ed876a344458d5c8bf3712f Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 17 Apr 2024 12:04:19 -0400 Subject: [PATCH 65/89] Revised punctuation in lora_requires_grad_policy_fn. --- vectorlm/utils/model_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index 858f6e7..d57d875 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -160,8 +160,8 @@ def lora_requires_grad_policy_fn(module: nn.Module) -> bool: FSDP requires consistent requires_grad for each flat param. - Since LoRA requires_grad tensors are embedded within each layer. - This policy "turns off" FSDP flat param optimization by + Since LoRA requires_grad tensors are embedded within each layer, + this policy "turns off" FSDP flat param optimization by requiring a separate flat param block for each tensor. """ if ( From afbc0617ea901113fd54db4198f9dd0aecb171f6 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Wed, 17 Apr 2024 13:33:17 -0400 Subject: [PATCH 66/89] Renamed declarative `enable_lora` with descriptive `is_lora_enabled`. --- examples/llama_example.py | 2 +- vectorlm/utils/model_utils.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 41a1dfd..a6986d8 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -114,7 +114,7 @@ def main(config: Config) -> None: training_args.sharding_strategy, local_rank, training_args.low_cpu_mem_usage, - enable_lora=(lora_peft_config is not None), + is_lora_enabled=(lora_peft_config is not None), ) # load dataset diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index d57d875..ba39c49 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -179,7 +179,7 @@ def fsdp_config( strategy: str, local_rank: int, low_cpu_mem_usage: bool, - enable_lora: bool = False, + is_lora_enabled: bool = False, ) -> dict[str, Any]: """Get FSDP config. @@ -191,7 +191,7 @@ def fsdp_config( local_rank: The local rank of the current worker. low_cpu_mem_usage: Whether to only load model weights on main rank, and then scatter them to the other workers. - enable_lora: Whether to enable LoRA support. + is_lora_enabled: Whether to enable LoRA support. Returns: ------- @@ -225,7 +225,7 @@ def _module_init_fn(module: nn.Module) -> Callable: transformer_layer_cls={layer_to_wrap}, ) - if enable_lora: + if is_lora_enabled: # turns off FSDP Flat Param in LoRA layers. lambda_requires_grad_policy = functools.partial( lambda_auto_wrap_policy, @@ -257,7 +257,7 @@ def shard_model( strategy: str, local_rank: int, low_cpu_mem_usage: bool, - enable_lora: bool = False, + is_lora_enabled: bool = False, ) -> nn.Module: """Shard the model to workers using FSDP. @@ -271,7 +271,7 @@ def shard_model( local_rank: The local rank of the current worker. low_cpu_mem_usage: Whether to only load model weights on main rank, and then scatter them to the other workers. - enable_lora: Whether to enable support for LoRA, where only a subset of + is_lora_enabled: Whether to enable support for LoRA, where only a subset of parameter tensors requires_grad. Enabling might significantly reduce training throughput, so enable this only when actually using LoRA. @@ -286,7 +286,7 @@ def shard_model( strategy, local_rank, low_cpu_mem_usage, - enable_lora, + is_lora_enabled, ) if dist.get_rank() == 0: print(f"FSDP config: {fsdp_cfg}") From 4936b1d97d82b89b289624e3f7df5d9c1dcaebcd Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 22 Apr 2024 11:38:50 -0400 Subject: [PATCH 67/89] Added (request for comment) AbstractInferenceEngine interface and LoRAInferenceEngine implementation. --- vectorlm/inference/__init__.py | 2 + vectorlm/inference/abstract.py | 64 +++++++++++++ vectorlm/inference/inference_lora.py | 133 +++++++++++++++++++++++++++ vectorlm/utils/save_utils.py | 2 +- 4 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 vectorlm/inference/__init__.py create mode 100644 vectorlm/inference/abstract.py create mode 100644 vectorlm/inference/inference_lora.py diff --git a/vectorlm/inference/__init__.py b/vectorlm/inference/__init__.py new file mode 100644 index 0000000..9289b3e --- /dev/null +++ b/vectorlm/inference/__init__.py @@ -0,0 +1,2 @@ +from .abstract import AbstractInferenceEngine +from .inference_lora import LoRAInferenceEngine diff --git a/vectorlm/inference/abstract.py b/vectorlm/inference/abstract.py new file mode 100644 index 0000000..938f09f --- /dev/null +++ b/vectorlm/inference/abstract.py @@ -0,0 +1,64 @@ +"""Wrapper around inference engine. + +Provides the following functionalities: +- Batch inference +- LoRA state tracking +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import vllm + +from vectorlm.trainer import Trainer + + +class AbstractInferenceEngine(ABC): + """Interface for the inference engine.""" + + def __init__( + self, + trainer: Trainer, + sampling_params: vllm.SamplingParams | None = None, + ) -> None: + """Initialize inference engine. + + Params: + trainer: Trainer instance. + sampling_params: Optionally, specify default sampling params. + + """ + self.trainer = trainer + self.sampling_params = sampling_params + + def update(self, trainer: Trainer | None = None) -> None: + """Inform the inference engine that the model in trainer is updated. + + Params: + trainer: Optionally, replace self.trainer with the provided value. + """ + if trainer is not None: + self.trainer = trainer + + @abstractmethod + def generate( + self, + prompts: list[str], + sampling_params: vllm.SamplingParams | None = None, + ) -> list[list[vllm.CompletionOutput]]: + """Generate continuation for the given prompts synchronously. + + Params: + ------ + prompts: List of input prompts. + sampling_params: Optionally, override self.sampling_params in + this request only. + + Returns + ------- + Output from vllm: list[list[vllm.CompletionOutput]] + outer layer: one for each prompt. + inner layer: one for each output option for the prompt. + + """ diff --git a/vectorlm/inference/inference_lora.py b/vectorlm/inference/inference_lora.py new file mode 100644 index 0000000..6d0f1c6 --- /dev/null +++ b/vectorlm/inference/inference_lora.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import os + +import torch +import torch.distributed as dist +import vllm +from vllm.lora.request import LoRARequest + +from vectorlm.trainer import Trainer +from vectorlm.utils.save_utils import save_peft_adapter + +from .abstract import AbstractInferenceEngine + + +class LoRAInferenceEngine(AbstractInferenceEngine): + """Inference engine optimized for inference during LoRA PEFT training.""" + + def __init__( + self, + trainer: Trainer, + sampling_params: vllm.SamplingParams | None = None, + base_model_name: str | None = None, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.3, + adapter_temp_folder: str | None = None, + ) -> None: + """Initialize inference engine. + + Params: + trainer: Trainer instance. + sampling_params: Optionally, specify default sampling params. + base_model_name: Path or HuggingFace repo name of base model. + tensor_parallel_size: Forwarded to vllm.LLM. + gpu_memory_utilization: Forwarded to vllm.LLM. + adapter_temp_folder: Temporary path where temporary adapter weights + are saved. If not specified, f`/dev/shm/{job_id}` + """ + if dist.get_rank() != 0: + return + + self.sampling_params = sampling_params + + if adapter_temp_folder is not None: + self.adapter_temp_folder = adapter_temp_folder + else: + slurm_job_id_or_placeholder = os.environ.get("SLURM_JOB_ID", "0") + + # Manually specify the in-memory /dev/shm filesystem + # to avoid disk wear and overhead. + self.adapter_base_folder = "/dev/shm/" # noqa: S108 + self.adapter_temp_folder = os.path.join( + self.adapter_base_folder, + slurm_job_id_or_placeholder, + ) + + assert ( + base_model_name is not None + ), "base_model_name is required when instantiating LoRAInferenceEngine." + + self.vllm_llm = vllm.LLM( + base_model_name, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + enable_lora=True, + ) + + # Trigger FSDP initialization before + _wrapped_model = trainer.model + assert _wrapped_model is not None + _wrapped_model(input_ids=torch.zeros((1, 1), dtype=torch.int)) + self.vllm_train_step = -1 + + self.update(trainer) + + def update(self, trainer: Trainer | None = None) -> None: + """Inform the inference engine that the model in trainer is updated. + + Params: + trainer: Optionally, replace self.trainer with the provided value. + """ + if dist.get_rank() != 0: + return + + if trainer is not None: + self.trainer = trainer + + wrapped_model = self.trainer.model + assert wrapped_model is not None + + if self.vllm_train_step != self.trainer.tr_step: + save_peft_adapter(wrapped_model, self.adapter_temp_folder) + assert self.trainer.tr_step is not None + assert self.trainer.tr_step >= 0 + self.vllm_train_step = self.trainer.tr_step + self.lora_request = LoRARequest( + "_vectorlm", + self.vllm_train_step + 1, + self.adapter_temp_folder, + ) + + def generate( + self, + prompts: list[str], + sampling_params: vllm.SamplingParams | None = None, + ) -> list[list[vllm.CompletionOutput]]: + """Generate continuation for the given prompts. Invoke only on rank 0. + + Params: + ------ + prompts: List of input prompts. + sampling_params: Optionally, override self.sampling_params in + this request only. + + Returns + ------- + Output from vllm: list[list[vllm.CompletionOutput]] + outer layer: one for each prompt. + inner layer: one for each output option for the prompt. + + """ + if dist.get_rank() != 0: + msg = "LoRA inference engine is supported only on rank 0." + raise RuntimeError(msg) + + assert self.vllm_train_step is not None + output_list = self.vllm_llm.generate( + prompts, + sampling_params, + lora_request=self.lora_request, + use_tqdm=True, + ) + return [output.outputs for output in output_list] diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index d064571..44b580e 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -192,7 +192,7 @@ def get_peft_adapter_tensor_dict( def save_peft_adapter( - model: peft.peft_model.PeftModel, + model: peft.peft_model.PeftModel | nn.Module, output_path: str, ) -> None: """Save peft adapter to filesystem in a FSDP environment.""" From aa1fe8b86cd55e3c3b939b2e613d2a15a6032c5f Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Mon, 22 Apr 2024 11:49:31 -0400 Subject: [PATCH 68/89] Renamed "inference" to "sampling". Misc documentation updates. --- vectorlm/inference/__init__.py | 2 -- vectorlm/sampling/__init__.py | 2 ++ vectorlm/{inference => sampling}/abstract.py | 15 +++++---------- .../sampling_lora.py} | 14 +++++++------- 4 files changed, 14 insertions(+), 19 deletions(-) delete mode 100644 vectorlm/inference/__init__.py create mode 100644 vectorlm/sampling/__init__.py rename vectorlm/{inference => sampling}/abstract.py (81%) rename vectorlm/{inference/inference_lora.py => sampling/sampling_lora.py} (90%) diff --git a/vectorlm/inference/__init__.py b/vectorlm/inference/__init__.py deleted file mode 100644 index 9289b3e..0000000 --- a/vectorlm/inference/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .abstract import AbstractInferenceEngine -from .inference_lora import LoRAInferenceEngine diff --git a/vectorlm/sampling/__init__.py b/vectorlm/sampling/__init__.py new file mode 100644 index 0000000..ebdc603 --- /dev/null +++ b/vectorlm/sampling/__init__.py @@ -0,0 +1,2 @@ +from .abstract import AbstractSamplingEngine +from .sampling_lora import LoRASamplingEngine diff --git a/vectorlm/inference/abstract.py b/vectorlm/sampling/abstract.py similarity index 81% rename from vectorlm/inference/abstract.py rename to vectorlm/sampling/abstract.py index 938f09f..51ed6f8 100644 --- a/vectorlm/inference/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -1,9 +1,4 @@ -"""Wrapper around inference engine. - -Provides the following functionalities: -- Batch inference -- LoRA state tracking -""" +"""Wrapper around sampling engine.""" from __future__ import annotations @@ -14,15 +9,15 @@ from vectorlm.trainer import Trainer -class AbstractInferenceEngine(ABC): - """Interface for the inference engine.""" +class AbstractSamplingEngine(ABC): + """Interface for the sampling engine.""" def __init__( self, trainer: Trainer, sampling_params: vllm.SamplingParams | None = None, ) -> None: - """Initialize inference engine. + """Initialize sampling engine. Params: trainer: Trainer instance. @@ -33,7 +28,7 @@ def __init__( self.sampling_params = sampling_params def update(self, trainer: Trainer | None = None) -> None: - """Inform the inference engine that the model in trainer is updated. + """Inform the sampling engine that the model in trainer is updated. Params: trainer: Optionally, replace self.trainer with the provided value. diff --git a/vectorlm/inference/inference_lora.py b/vectorlm/sampling/sampling_lora.py similarity index 90% rename from vectorlm/inference/inference_lora.py rename to vectorlm/sampling/sampling_lora.py index 6d0f1c6..d4cc00d 100644 --- a/vectorlm/inference/inference_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -10,11 +10,11 @@ from vectorlm.trainer import Trainer from vectorlm.utils.save_utils import save_peft_adapter -from .abstract import AbstractInferenceEngine +from .abstract import AbstractSamplingEngine -class LoRAInferenceEngine(AbstractInferenceEngine): - """Inference engine optimized for inference during LoRA PEFT training.""" +class LoRASamplingEngine(AbstractSamplingEngine): + """Sampling engine optimized for LoRA PEFT.""" def __init__( self, @@ -25,7 +25,7 @@ def __init__( gpu_memory_utilization: float = 0.3, adapter_temp_folder: str | None = None, ) -> None: - """Initialize inference engine. + """Initialize sampling engine. Params: trainer: Trainer instance. @@ -56,7 +56,7 @@ def __init__( assert ( base_model_name is not None - ), "base_model_name is required when instantiating LoRAInferenceEngine." + ), "base_model_name is required when instantiating LoRASamplingEngine." self.vllm_llm = vllm.LLM( base_model_name, @@ -74,7 +74,7 @@ def __init__( self.update(trainer) def update(self, trainer: Trainer | None = None) -> None: - """Inform the inference engine that the model in trainer is updated. + """Inform the sampling engine that the model in trainer is updated. Params: trainer: Optionally, replace self.trainer with the provided value. @@ -120,7 +120,7 @@ def generate( """ if dist.get_rank() != 0: - msg = "LoRA inference engine is supported only on rank 0." + msg = "LoRA sampling engine is supported only on rank 0." raise RuntimeError(msg) assert self.vllm_train_step is not None From 675367b1ddb28e2a0c0224ed23537c7d9743c629 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 25 Apr 2024 16:33:24 -0400 Subject: [PATCH 69/89] Added reference sampling steps to llama_example. Added example sampling configs and documentations. --- configs/config.yaml | 14 ++++++- docs/config.md | 16 ++++++++ examples/llama_example.py | 35 ++++++++++++++-- vectorlm/sampling/utils.py | 76 +++++++++++++++++++++++++++++++++++ vectorlm/utils/model_utils.py | 7 ++-- 5 files changed, 140 insertions(+), 8 deletions(-) create mode 100644 vectorlm/sampling/utils.py diff --git a/configs/config.yaml b/configs/config.yaml index a2a03e7..6758a2b 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -21,7 +21,7 @@ train_parameters: low_cpu_mem_usage: True # LoRA config: uncomment the block below to enable LoRA - + # lora_peft_config: # task_type: CAUSAL_LM # inference_mode: False @@ -29,7 +29,6 @@ train_parameters: # lora_alpha: 32 # lora_dropout: 0.1 - # Gradient norm clipping max_grad_norm: 1 gradient_accumulation_steps: 4 @@ -50,6 +49,17 @@ train_parameters: logging_steps: 500 save_frequency: 0.25 + # Sampling during training + # Uncomment the block below to enable. + + # sampler: + # sample_frequency: 8 + # output_jsonl_path: data/output.jsonl + # prompts: + # - "Vector Institute of the" + # - "Vector Institute is located in the city of" + # - "The answer to life the universe and everything is" + dataset: ignore_index: -100 eval_bs: 8 diff --git a/docs/config.md b/docs/config.md index ef0de37..197839c 100644 --- a/docs/config.md +++ b/docs/config.md @@ -51,6 +51,22 @@ Similar to the wandb config above, these keyword parameters are fed directly int * `logging_steps`: How often evaluation is run using the evaluation dataset. * `save_frequency`: The frequency at which checkpointing occurs. This must be between 0 and 1. + +### Sampling during Training + +To disable sampling during training, delete the entire "sampling" section. + +* `sample_frequency`: Number of train steps between two consecutive sampling steps. +* `output_jsonl_path`: Optional; write sampled output to the specified jsonl file. +* `prompts`: YAML list of prompt strings. + +Each line of the output jsonl file would be a dictionary with keys: + +* `tr_step`: number (integer), trainer step when this line was generated. +* `prompt`: string. +* `options`: list of strings, one for each possible option that the sampler provided. +* `time_taken`: float, number of seconds taken to generate **all** prompts at this step. + ## Dataset * `ignore_index`: The integer index used to ignore a given token in the loss calculation. Cross-entropy loss by default uses `-100`. diff --git a/examples/llama_example.py b/examples/llama_example.py index a6986d8..852294f 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -11,8 +11,11 @@ from torch.optim import AdamW from tqdm import tqdm from transformers import set_seed +from vllm import SamplingParams from vectorlm.dataset import Dataset +from vectorlm.sampling import LoRASamplingEngine +from vectorlm.sampling.utils import handle_sample from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup @@ -27,6 +30,7 @@ checkpoint_exists, get_latest_checkpoint_dir, save_consolidated_model, + save_peft_adapter, ) @@ -159,10 +163,17 @@ def main(config: Config) -> None: # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + if training_args.sampler is not None: + sampling_engine = LoRASamplingEngine( + trainer, + SamplingParams(seed=0), + gpu_memory_utilization=0.3, + base_model_name=config.model, + ) + for epoch in range(checkpointed_epoch, training_args.epochs): - trainer.model.train() train_dl_iterator = iter(dataset.train_dataloader) - for _ in tqdm( + for index in tqdm( range(len(dataset.train_dataloader)), disable=rank != 0, file=sys.__stdout__, @@ -170,6 +181,19 @@ def main(config: Config) -> None: batch = next(train_dl_iterator) trainer.step(batch, epoch) + if ( + (training_args.sampler is not None) + and (index % training_args.sampler.sample_frequency == 0) + and (dist.get_rank() == 0) + ): + sampling_engine.update(trainer) + handle_sample( + sampling_engine, + training_args.sampler.prompts, + training_args.sampler.output_jsonl_path, + extra_data={"tr_step": trainer.tr_step}, + ) + if epoch == training_args.epochs - 1: hf_save_dir = os.path.join(training_args.output_dir, "final-model") else: @@ -179,7 +203,12 @@ def main(config: Config) -> None: f"epoch_{epoch}", "end-epoch-model", ) - save_consolidated_model(trainer.model, hf_save_dir, rank) + # Save base (consolidated) model only when not running peft. + if lora_peft_config is None: + save_consolidated_model(trainer.model, hf_save_dir, rank) + else: + save_peft_adapter(trainer.model, hf_save_dir) + dataset.reset_dataloaders() diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py new file mode 100644 index 0000000..9bc3a11 --- /dev/null +++ b/vectorlm/sampling/utils.py @@ -0,0 +1,76 @@ +"""Generic utils for the sampling engines.""" + +from __future__ import annotations + +import json +import time +from typing import Any, Iterable, NamedTuple + +from vllm import SamplingParams + +from .abstract import AbstractSamplingEngine + + +class SampleOutput(NamedTuple): + """Represents possible responses to a prompt. + + Params: + prompt: prompt string. + options: list of proposed responses to this prompt. + """ + + prompt: str + options: list[str] + time_taken: float + + +def handle_sample( + sampling_engine: AbstractSamplingEngine, + prompts: Iterable[str], + output_path: str | None, + sampling_params: SamplingParams | None = None, + extra_data: dict[str, Any] | None = None, +) -> list[SampleOutput]: + """Sample continuations and optionally save to disk. + + Params: + ------ + sampling_engine: an instantiation of sampling engine. + prompts: a list (iterable) of prompts. + output_path: if provided, append output json lines to this file. + sampling_params: forwarded to sampling engine. + extra_data: prepended to each line of output (e.g., current epoch.) + + Returns + ------- + List of SampleOutput, representing continuations for each prompt. + + """ + _prompts = list(prompts) + + start_time = time.time() + generation_output = sampling_engine.generate(_prompts, sampling_params) + time_taken = time.time() - start_time + + # Parse sample engine output and keep only the output strings. + sample_outputs: list[SampleOutput] = [] + for prompt, options in zip(prompts, generation_output): + sample_outputs.append( + SampleOutput( + prompt, + [option.text for option in options], + time_taken, + ), + ) + + # note: always produce jsonl_output_lines to ensure code coverage. + extra_data = extra_data if extra_data is not None else {} + jsonl_output_lines: list[str] = [ + json.dumps({**extra_data, **sample_output._asdict()}) + for sample_output in sample_outputs + ] + if output_path is not None: + with open(output_path, "a") as output_jsonl_file: + output_jsonl_file.write("\n".join(jsonl_output_lines) + "\n\n") + + return sample_outputs diff --git a/vectorlm/utils/model_utils.py b/vectorlm/utils/model_utils.py index ba39c49..31556fe 100644 --- a/vectorlm/utils/model_utils.py +++ b/vectorlm/utils/model_utils.py @@ -271,9 +271,10 @@ def shard_model( local_rank: The local rank of the current worker. low_cpu_mem_usage: Whether to only load model weights on main rank, and then scatter them to the other workers. - is_lora_enabled: Whether to enable support for LoRA, where only a subset of - parameter tensors requires_grad. Enabling might significantly reduce - training throughput, so enable this only when actually using LoRA. + is_lora_enabled: Whether to enable support for LoRA, where only a subset + of parameter tensors requires_grad. Enabling might significantly + reduce training throughput, so enable this only when actually using + LoRA. Returns: ------- From ca2cad8e0151957d40ad0da829ccae78934e0852 Mon Sep 17 00:00:00 2001 From: Jacob-Junqi Tian Date: Thu, 25 Apr 2024 17:14:40 -0400 Subject: [PATCH 70/89] Added train_parameters.get("sampler"). --- examples/llama_example.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 852294f..b860700 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -88,6 +88,7 @@ def main(config: Config) -> None: "lora_peft_config", None, ) + sampler_config = config.train_parameters.get("sampler") is_peft_adapter_restored = False if lora_peft_config is not None: peft_adapter_path = None @@ -163,12 +164,13 @@ def main(config: Config) -> None: # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - if training_args.sampler is not None: + if sampler_config is not None: sampling_engine = LoRASamplingEngine( trainer, SamplingParams(seed=0), gpu_memory_utilization=0.3, base_model_name=config.model, + tensor_parallel_size=world_size, ) for epoch in range(checkpointed_epoch, training_args.epochs): @@ -182,7 +184,7 @@ def main(config: Config) -> None: trainer.step(batch, epoch) if ( - (training_args.sampler is not None) + (sampler_config is not None) and (index % training_args.sampler.sample_frequency == 0) and (dist.get_rank() == 0) ): @@ -193,6 +195,7 @@ def main(config: Config) -> None: training_args.sampler.output_jsonl_path, extra_data={"tr_step": trainer.tr_step}, ) + dist.barrier() if epoch == training_args.epochs - 1: hf_save_dir = os.path.join(training_args.output_dir, "final-model") From 649a4b878b36eb211898cf5df32a052f336fe780 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Mon, 6 May 2024 08:59:51 -0400 Subject: [PATCH 71/89] [WIP] Implemented vLLM wrapper combining vectorlm and vLLM workers. --- .gitignore | 3 +- configs/config_gemma.yaml | 68 +++++++++++ examples/llama_example.py | 63 +++++----- examples/llama_example_mp.py | 208 ++++++++++++++++++++++++++++++++ examples/train_and_inference.py | 167 +++++++++++++++++++++++++ profiling/launch_benchmark.py | 2 +- vectorlm/sampling/utils.py | 145 +++++++++++++++++++++- 7 files changed, 621 insertions(+), 35 deletions(-) create mode 100644 configs/config_gemma.yaml create mode 100644 examples/llama_example_mp.py create mode 100644 examples/train_and_inference.py diff --git a/.gitignore b/.gitignore index 05c42cc..f4c948d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ data/ **/*.pyc /.cache /.vscode -/data \ No newline at end of file +/data +/env diff --git a/configs/config_gemma.yaml b/configs/config_gemma.yaml new file mode 100644 index 0000000..6668cbd --- /dev/null +++ b/configs/config_gemma.yaml @@ -0,0 +1,68 @@ +model: google/gemma-2b +enable_wandb_logging: False + +wandb_config: + project: vector-lm-verify + name: benchmark-lora + # tags: ["20240418-1a-preemption"] + +train_parameters: + output_dir: /network/scratch/j/jacob-junqi.tian/vectorlm/weights + max_seq_len: 128 + epochs: 10 + seed: 11 + + # Sharding strategy + sharding_strategy: FULL_SHARD + + # Memory + use_mp: True + use_activation_checkpointing: True + # use_flash_attention is automatically enabled + # for CUDA capability > 8.0 + use_flash_attention: False + low_cpu_mem_usage: True + + lora_peft_config: + task_type: CAUSAL_LM + inference_mode: False + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + + # Gradient norm clipping + max_grad_norm: 1 + gradient_accumulation_steps: 4 + + # Optimizer + optimizer: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.95] + eps: 1.0e-5 + + # Scheduler + lr_scheduler_type: cosine + warmup_ratio: 0.05 + + # Checkpointing + checkpointing_enabled: False + logging_steps: 10 + save_frequency: 0.10 + + # Sampling during training + sampler: + sample_frequency: 8 + output_jsonl_path: data/output-5e-5-2b.jsonl + vllm_dtype: half + prompts: + - "Vector Institute of the" + - "Vector Institute is located in the city of" + - "The answer to life the universe and everything is" + +dataset: + ignore_index: -100 + eval_bs: 8 + train_bs: 8 + train_ds: data/processed/vector-west/train + eval_ds: data/processed/vector-west/test diff --git a/examples/llama_example.py b/examples/llama_example.py index b860700..fb6539f 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -1,21 +1,21 @@ from __future__ import annotations import argparse +import logging import math import os import sys from argparse import Namespace +from threading import Barrier +from typing import Callable import torch import torch.distributed as dist from torch.optim import AdamW from tqdm import tqdm from transformers import set_seed -from vllm import SamplingParams from vectorlm.dataset import Dataset -from vectorlm.sampling import LoRASamplingEngine -from vectorlm.sampling.utils import handle_sample from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup @@ -51,7 +51,13 @@ def parse_args() -> Namespace: return parser.parse_args() -def main(config: Config) -> None: +def main( + config: Config, + local_rank: int | None = None, + world_size: int | None = None, + dist_init_barrier: Barrier | None = None, + vllm_init_callback: Callable[[], None] | None = None, +) -> None: """Define the main calling function.""" training_args = config.train_parameters @@ -59,9 +65,20 @@ def main(config: Config) -> None: set_seed(training_args.seed) # set CUDA related dependencies - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) + if (local_rank is None) or (world_size is None): + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + else: + rank = local_rank # modify if going beyond one node. + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + + logging.info( + "dist.init_process_group", + extra={"local_rank": local_rank, "world_size": world_size}, + ) print(f"Rank: {rank}, World size: {world_size}") if dist.is_initialized(): @@ -71,7 +88,6 @@ def main(config: Config) -> None: # setup wandb if rank == 0: wandb_setup(config, **config.wandb_config) - dist.barrier() # load model and tokenizer model, tokenizer = load_model_and_tokenizer( @@ -164,15 +180,6 @@ def main(config: Config) -> None: # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - if sampler_config is not None: - sampling_engine = LoRASamplingEngine( - trainer, - SamplingParams(seed=0), - gpu_memory_utilization=0.3, - base_model_name=config.model, - tensor_parallel_size=world_size, - ) - for epoch in range(checkpointed_epoch, training_args.epochs): train_dl_iterator = iter(dataset.train_dataloader) for index in tqdm( @@ -183,19 +190,15 @@ def main(config: Config) -> None: batch = next(train_dl_iterator) trainer.step(batch, epoch) - if ( - (sampler_config is not None) - and (index % training_args.sampler.sample_frequency == 0) - and (dist.get_rank() == 0) - ): - sampling_engine.update(trainer) - handle_sample( - sampling_engine, - training_args.sampler.prompts, - training_args.sampler.output_jsonl_path, - extra_data={"tr_step": trainer.tr_step}, - ) - dist.barrier() + # if ( + # (sampler_config is not None) + # and (index % training_args.sampler.sample_frequency == 0) + # and (dist.get_rank() == 0) + # ): + # output = vllm_llm.generate(training_args.sampler.prompts) + # print(output) + + # dist.barrier() if epoch == training_args.epochs - 1: hf_save_dir = os.path.join(training_args.output_dir, "final-model") diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py new file mode 100644 index 0000000..15de64e --- /dev/null +++ b/examples/llama_example_mp.py @@ -0,0 +1,208 @@ +"""llama_example, but uses multiprocessing in place of torchrun""" + +from __future__ import annotations + +import argparse +import logging +import multiprocessing +import multiprocessing.context +import multiprocessing.managers +import threading +from functools import partial +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from vllm.worker.worker_base import WorkerBase + +from llama_example import main +from vllm.engine.arg_utils import EngineArgs, EngineConfig +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.local_worker_utils import LocalWorkerVllm, ResultHandler +from vllm.entrypoints.llm import LLM +from vllm.worker.worker import init_worker_distributed_environment + +from vectorlm.sampling.utils import ( + ManagedLLM, + ManagedMultiProcGPUExecutor, + _ensure_torch_dist_is_initialized, + _get_rdvz_url, + get_vllm_worker_factory, +) +from vectorlm.utils.data_utils import Config + +logging.basicConfig(level=logging.DEBUG) + +mp = multiprocessing.get_context("fork") + + +class _VLLMCallbackWrapper: + """Provide vLLM Engine access to multiprocess.Process workers. + + vLLM engine is initialized only after the initialize_engine call. + """ + + def __init__( + self, + non_driver_workers: list[VectorLMWorker], + engine_config: EngineConfig, + vectorlm_config: Config, + world_size: int, + ) -> None: + """Instantiate class without initializing wrapped vLLM engine.""" + self.llm_engine: LLMEngine | None = None + self.llm: LLM | None = None + self.non_driver_workers = non_driver_workers + self.engine_config = engine_config + + # torch.dist init barrier for rank 0 vectorlm process. + # ensures rank 0 vectorlm achieves torch.dist + # before starting rank 0 Worker. + self.root_vectorlm_dist_init_barrier = threading.Barrier(2) + self.vectorlm_main_fn = partial( + main, + vectorlm_config, + 0, + world_size, + self.root_vectorlm_dist_init_barrier, + ) + + def initialize_engine(self) -> None: + """Initialize vLLM engine. + + Invoke this method only after vLLM workers are all ready. + """ + ManagedMultiProcGPUExecutor.workers = tuple( + self.non_driver_workers, + ) + ManagedMultiProcGPUExecutor.vectorlm_main_fn = self.vectorlm_main_fn + ManagedMultiProcGPUExecutor.vectorlm_dist_init_barrier = ( + self.root_vectorlm_dist_init_barrier + ) + + self.llm_engine = LLMEngine( + **self.engine_config.to_dict(), + executor_class=ManagedMultiProcGPUExecutor, + log_stats=False, + ) + self.llm = ManagedLLM(self.llm_engine) + print(f"Instantiated ManagedLLM: {self.llm}") + + +class VectorLMWorker(LocalWorkerVllm): + """Worker for running VectorLM logic alongside vLLM worker. + + Important: do not use this instance for the rank 0 (root) process. + + Note that nccl requires that only one process may have access + to each GPU. Each LocalWorkerVllm is a multiprocessing.Process. + Vectorlm logic would be launched as a thread within each of these + proceses. + + Spawn no more than one such instance for each GPU. + """ + + def __init__( + self, + result_handler: ResultHandler, + worker_factory: Callable[[], WorkerBase], + vllm_engine_config: EngineConfig, + vectorlm_config: Config, + local_rank: int, + world_size: int, + ) -> None: + """Instantiate LocalWorkerVllm wrapper. + + vectorlm_dist_init_barrier ensures that torch.dist is initialized in + the vectorlm thread and not the main thread (vllm) of the process. + """ + self.vllm_engine_config = vllm_engine_config + self.vectorlm_dist_init_barrier = threading.Barrier(2) + self.vectorlm_config = vectorlm_config + self.local_rank = local_rank + self.world_size = world_size + self.vllm_init_callback: Callable[[], None] | None = None + + super().__init__(result_handler, worker_factory) + + def run(self) -> None: + """Launch vectorlm logic in a separate thread.""" + print(f"rank {self.local_rank}: init_worker_dist started") + init_worker_distributed_environment( + self.vllm_engine_config.parallel_config, + self.local_rank, + _get_rdvz_url(), + self.local_rank, + ) + print(f"rank {self.local_rank}: init_worker_dist completed") + + _ensure_torch_dist_is_initialized() + + self.vectorlm_thread = threading.Thread( + target=main, + args=( + self.vectorlm_config, + self.local_rank, + self.world_size, + self.vectorlm_dist_init_barrier, + self.vllm_init_callback, + ), + name=f"rank-{self.local_rank}/vectorlm", + ) + self.vectorlm_thread.start() + + super().run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--world_size", type=int, default=1) + parser.add_argument("--yaml_path", type=str, required=True) + args = parser.parse_args() + + world_size: int = args.world_size + vectorlm_config = Config(yaml_path=args.yaml_path) + sampler_config = vectorlm_config.train_parameters.sampler # type: ignore[] + vllm_engine_config = EngineArgs( + model=vectorlm_config.model, # type: ignore[] + gpu_memory_utilization=sampler_config.get( + "gpu_memory_utilization", + 0.35, + ), + tensor_parallel_size=world_size, + dtype=sampler_config.get("vllm_dtype", "auto"), + ).create_engine_config() + + vllm_result_handler = ResultHandler() + + # rank 0 worker runs in the __main__ process. + # all other ranks use one process each. + # vectorlm logic in each ranks (including rank 0) is in a separate thread. + non_driver_workers: list[VectorLMWorker] = [ + VectorLMWorker( + vllm_result_handler, + get_vllm_worker_factory( + vllm_engine_config, + _get_rdvz_url(), + local_rank, + ), + vllm_engine_config, + vectorlm_config, + local_rank, + world_size=world_size, + ) + for local_rank in range(1, world_size) + ] + vllm_callback_wrapper = _VLLMCallbackWrapper( + non_driver_workers, + vllm_engine_config, + vectorlm_config, + world_size, + ) + + for worker in non_driver_workers: + worker.start() + + vllm_callback_wrapper.initialize_engine() + assert vllm_callback_wrapper.llm is not None + output = vllm_callback_wrapper.llm.generate("Vector Institute is") + print(output) diff --git a/examples/train_and_inference.py b/examples/train_and_inference.py new file mode 100644 index 0000000..305cd5a --- /dev/null +++ b/examples/train_and_inference.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import argparse +import math +import os +import sys +from argparse import Namespace + +import torch +import torch.distributed as dist +from torch.optim import AdamW +from tqdm import tqdm +from transformers import set_seed + +from vectorlm.dataset import Dataset +from vectorlm.trainer import Trainer +from vectorlm.utils.data_utils import Config +from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup +from vectorlm.utils.model_utils import ( + get_lora_model_from_base_model, + get_submodule_by_pattern, + load_model_and_tokenizer, + shard_model, +) +from vectorlm.utils.optimizer_utils import get_custom_scheduler +from vectorlm.utils.save_utils import save_consolidated_model + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns + ------- + The parsed arguments. + + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--yaml_path", + default="configs/config.yaml", + required=False, + ) + return parser.parse_args() + + +def main(config: Config) -> None: + """Define the main calling function.""" + training_args = config.train_parameters + + # set a seed + set_seed(training_args.seed) + + # set CUDA related dependencies + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + print(f"Rank: {rank}, World size: {world_size}") + if dist.is_initialized(): + torch.cuda.set_device(local_rank) + torch.cuda.empty_cache() + + # setup wandb + if rank == 0: + wandb_setup(config, **config.wandb_config) + dist.barrier() + + # load model and tokenizer + model, tokenizer = load_model_and_tokenizer( + config.model, + training_args.use_mp, + training_args.use_flash_attention, + training_args.max_seq_len, + local_rank, + training_args.low_cpu_mem_usage, + ) + + lora_peft_config = getattr( + config.train_parameters, + "lora_peft_config", + None, + ) + if lora_peft_config is not None: + model = get_lora_model_from_base_model(model, lora_peft_config) + + decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") + model = shard_model( + model.bfloat16(), + decoder_layer_module, + training_args.use_mp, + training_args.use_activation_checkpointing, + training_args.sharding_strategy, + local_rank, + training_args.low_cpu_mem_usage, + ) + + # load dataset + dataset = Dataset( + config=config.dataset, + tokenizer=tokenizer, + ) + + # instantiate trainer + trainer = Trainer( + config=training_args, + enable_wandb_logging=config.enable_wandb_logging, + original_dataset_length=dataset.original_length, + ) + + # load optimizer + optimizer = AdamW( + model.parameters(), + **training_args.optimizer, + ) + + # load lr scheduler + lr_scheduler = get_custom_scheduler( + training_args.lr_scheduler_type, + optimizer, + math.ceil( + trainer.num_update_steps_per_epoch * training_args.warmup_ratio, + ), + trainer.max_steps, + ) + + trainer.prepare_trainer( + model, + tokenizer, + dataset, + optimizer, + lr_scheduler, + ) + + # Checkpoint check. Always call before training. + # If no checkpoint, it returns 0. + checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + + for epoch in range(checkpointed_epoch, training_args.epochs): + trainer.model.train() + train_dl_iterator = iter(dataset.train_dataloader) + for _ in tqdm( + range(len(dataset.train_dataloader)), + disable=rank != 0, + file=sys.__stdout__, + ): + batch = next(train_dl_iterator) + trainer.step(batch, epoch) + + if epoch == training_args.epochs - 1: + hf_save_dir = os.path.join(training_args.output_dir, "final-model") + else: + hf_save_dir = os.path.join( + training_args.output_dir, + "checkpoints", + f"epoch_{epoch}", + "end-epoch-model", + ) + save_consolidated_model(trainer.model, hf_save_dir, rank) + dataset.reset_dataloaders() + + +if __name__ == "__main__": + args = parse_args() + config = Config(yaml_path=args.yaml_path) + setup(config.train_parameters.output_dir) + main(config) + cleanup() diff --git a/profiling/launch_benchmark.py b/profiling/launch_benchmark.py index e9509ed..3f7fe51 100644 --- a/profiling/launch_benchmark.py +++ b/profiling/launch_benchmark.py @@ -50,7 +50,7 @@ } num_repeats = 2 -slurm_flags_extra = {"time": "01:00:00", "qos": qos_selected} +slurm_flags_extra = {"time": "02:00:00", "qos": qos_selected} slurm_pos_args_options = [ ["profiling/launch_benchmark.sh"], diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 9bc3a11..4b3a56a 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -3,10 +3,36 @@ from __future__ import annotations import json +import os +import threading import time -from typing import Any, Iterable, NamedTuple - -from vllm import SamplingParams +from collections import Counter +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple + +if TYPE_CHECKING: + from vllm import LLMEngine + from vllm.worker.worker_base import WorkerBase + +from vllm import LLM, SamplingParams +from vllm.engine.local_worker_utils import ( + LocalWorkerVllm, + ResultHandler, + WorkerMonitor, +) +from vllm.executor.multiproc_gpu_executor import ( + MultiProcGPUExecutor, + _create_worker, +) +from vllm.utils import get_distributed_init_method, set_cuda_visible_devices +from vllm.worker.worker import init_worker_distributed_environment + +if TYPE_CHECKING: + from threading import Barrier + + from vllm.engine.arg_utils import EngineConfig + + from vectorlm.utils import Config from .abstract import AbstractSamplingEngine @@ -24,6 +50,97 @@ class SampleOutput(NamedTuple): time_taken: float +def _ensure_torch_dist_is_initialized() -> None: + import torch.distributed + + assert torch.distributed.is_initialized() + + +def _get_rdvz_url() -> str: + """Obtain rendezvous url for Torch dist.""" + return get_distributed_init_method( + os.environ.get("MASTER_ADDR", "127.0.0.1"), + int(os.environ["MASTER_PORT"]), + ) + + +class ManagedMultiProcGPUExecutor(MultiProcGPUExecutor): + """MultiProcGPUExecutor, but with worker processes instantiated outside.""" + + workers: tuple[LocalWorkerVllm, ...] | None = None + vectorlm_main_fn: Callable[[], None] | None = None + vectorlm_dist_init_barrier: Barrier | None = None + + def _init_executor(self) -> None: + """Initialize executor without initializing workers. + + Same as MultiProcGPUExecutor but assumes self.workers is already set. + + Mostly reproduced from + vllm/vllm-ray-optional/vllm/executor/multiproc_gpu_executor.py + """ + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. + world_size = self.parallel_config.tensor_parallel_size + assert self.workers is not None + assert len(self.workers) == world_size - 1, ( + f"non-driver workers len(self.workers): {len(self.workers)} " + f"should be (world_size - 1) {world_size - 1}" + ) + + if "CUDA_VISIBLE_DEVICES" not in os.environ: + set_cuda_visible_devices(range(world_size)) + + from torch.cuda import device_count + + assert ( + world_size <= device_count() + ), "please set tensor_parallel_size to less than max local gpu count" + + result_handler = ResultHandler() + self.worker_monitor = WorkerMonitor( + list(self.workers), + result_handler, + ) + result_handler.start() + self.worker_monitor.start() + + distributed_init_method = _get_rdvz_url() + + # driver worker is of rank 0 + print("driver worker: init_worker_dist started") + init_worker_distributed_environment( + self.parallel_config, + 0, + distributed_init_method, + 0, + ) + print("driver worker: init_worker_dist completed") + _ensure_torch_dist_is_initialized() + + # start vectorlm logic in the same Python process + # (albeit in a separate thread) + vectorlm_thread = threading.Thread( + target=self.vectorlm_main_fn, + name="driver/vectorlm", + ) + vectorlm_thread.start() + + self._init_driver_worker_and_model(0, 0, distributed_init_method) + + +class ManagedLLM(LLM): + """vllm.entrypoints.LLM but using an externally-initialized LLM Engine.""" + + def __init__(self, llm_engine: LLMEngine) -> None: + """Instantiate LLM instance using externally-initialized LLM Engine.""" + self.llm_engine = llm_engine + self.request_counter = Counter() + + def handle_sample( sampling_engine: AbstractSamplingEngine, prompts: Iterable[str], @@ -74,3 +191,25 @@ def handle_sample( output_jsonl_file.write("\n".join(jsonl_output_lines) + "\n\n") return sample_outputs + + +def get_vllm_worker_factory( + engine_config: EngineConfig, + distributed_init_method: str, + rank: int, +) -> Callable[[], WorkerBase]: + """Initialize vLLM worker.""" + return partial( + _create_worker, + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=engine_config.lora_config, + vision_language_config=engine_config.vision_language_config, + tensorizer_config=engine_config.tensorizer_config, + ) From ebb7bc9103612758d47e410eb3b46a5dcbf94d65 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Mon, 6 May 2024 09:47:11 -0400 Subject: [PATCH 72/89] vllm integration: Eliminated duplicate vllm ResultHandler. --- examples/llama_example.py | 14 +++++-- examples/llama_example_mp.py | 74 +++++++++++++++++++++++++++++------- vectorlm/sampling/utils.py | 23 ++++++----- 3 files changed, 83 insertions(+), 28 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index fb6539f..65058cc 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -6,8 +6,7 @@ import os import sys from argparse import Namespace -from threading import Barrier -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch import torch.distributed as dist @@ -33,6 +32,9 @@ save_peft_adapter, ) +if TYPE_CHECKING: + from threading import Barrier + def parse_args() -> Namespace: """Parse command-line arguments. @@ -55,10 +57,16 @@ def main( config: Config, local_rank: int | None = None, world_size: int | None = None, - dist_init_barrier: Barrier | None = None, + vllm_init_barrier: Barrier | None = None, vllm_init_callback: Callable[[], None] | None = None, ) -> None: """Define the main calling function.""" + if vllm_init_barrier is not None: + # Wait until vllm engine is ready. + print(f"rank {local_rank} vllm_init_barrier wait") + vllm_init_barrier.wait() + print(f"rank {local_rank} vllm_init_barrier cleared") + training_args = config.train_parameters # set a seed diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index 15de64e..d7968e7 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -1,4 +1,17 @@ -"""llama_example, but uses multiprocessing in place of torchrun""" +"""llama_example, but uses multiprocessing in place of torchrun. + +Each non-rank-0 worker process should spawn vectorlm logic in a +separate thread (but same process) but won't run the actual +vectorlm logic until the vLLM Engine is initialized- inference +weights loaded into each worker. + +To do so without duplicating vLLM code, observe that only the main process +(rank 0) is aware that vLLM engine was initialized properly +(when LLMEngine.__init__ returns.) Hence, one way to implement this +setup is to block the vectorlm thread with a multiprocessing synchronization +feature (e.g., a Barrier shared across all processes) that the rank 0 process +can remotely unblock. +""" from __future__ import annotations @@ -32,8 +45,6 @@ logging.basicConfig(level=logging.DEBUG) -mp = multiprocessing.get_context("fork") - class _VLLMCallbackWrapper: """Provide vLLM Engine access to multiprocess.Process workers. @@ -44,26 +55,31 @@ class _VLLMCallbackWrapper: def __init__( self, non_driver_workers: list[VectorLMWorker], + vllm_result_handler: ResultHandler, engine_config: EngineConfig, vectorlm_config: Config, world_size: int, + vllm_init_barrier: threading.Barrier, ) -> None: """Instantiate class without initializing wrapped vLLM engine.""" self.llm_engine: LLMEngine | None = None self.llm: LLM | None = None self.non_driver_workers = non_driver_workers + self.vllm_result_handler = vllm_result_handler self.engine_config = engine_config + self.vllm_init_barrier = vllm_init_barrier + + # Might not be required since LLM.generate is blocking. + # torch.dist.barrier might be sufficient for blocking + # other worker processes. + self.gpu_access_lock = threading.Lock() - # torch.dist init barrier for rank 0 vectorlm process. - # ensures rank 0 vectorlm achieves torch.dist - # before starting rank 0 Worker. - self.root_vectorlm_dist_init_barrier = threading.Barrier(2) self.vectorlm_main_fn = partial( main, vectorlm_config, 0, world_size, - self.root_vectorlm_dist_init_barrier, + self.vllm_init_barrier, ) def initialize_engine(self) -> None: @@ -75,18 +91,39 @@ def initialize_engine(self) -> None: self.non_driver_workers, ) ManagedMultiProcGPUExecutor.vectorlm_main_fn = self.vectorlm_main_fn - ManagedMultiProcGPUExecutor.vectorlm_dist_init_barrier = ( - self.root_vectorlm_dist_init_barrier - ) + ManagedMultiProcGPUExecutor.result_handler = self.vllm_result_handler self.llm_engine = LLMEngine( **self.engine_config.to_dict(), executor_class=ManagedMultiProcGPUExecutor, log_stats=False, ) + print("main: vllm_init_barrier waiting") + self.vllm_init_barrier.wait() + print("main: vllm_init_barrier cleared") + self.llm = ManagedLLM(self.llm_engine) print(f"Instantiated ManagedLLM: {self.llm}") + def get_engine(self) -> None: + """Return LLM instance. + + Invoke this method only within the main (rank 0 driver) process. + """ + if self.llm is None: + self.initialize_engine() + + def join_vectorlm_thread(self) -> None: + """Join the rank 0 (main process) vectorlm thread. + + Invoke this function only after initialize_engine. + """ + assert self.llm_engine is not None + model_executor = self.llm_engine.model_executor + assert isinstance(model_executor, ManagedMultiProcGPUExecutor) + + model_executor.vectorlm_thread.join() + class VectorLMWorker(LocalWorkerVllm): """Worker for running VectorLM logic alongside vLLM worker. @@ -109,6 +146,7 @@ def __init__( vectorlm_config: Config, local_rank: int, world_size: int, + vllm_init_barrier: threading.Barrier, ) -> None: """Instantiate LocalWorkerVllm wrapper. @@ -116,10 +154,12 @@ def __init__( the vectorlm thread and not the main thread (vllm) of the process. """ self.vllm_engine_config = vllm_engine_config - self.vectorlm_dist_init_barrier = threading.Barrier(2) + self.gpu_access_lock = threading.Lock() self.vectorlm_config = vectorlm_config self.local_rank = local_rank self.world_size = world_size + self.vllm_init_barrier = vllm_init_barrier + self.vllm_init_callback: Callable[[], None] | None = None super().__init__(result_handler, worker_factory) @@ -143,7 +183,7 @@ def run(self) -> None: self.vectorlm_config, self.local_rank, self.world_size, - self.vectorlm_dist_init_barrier, + self.vllm_init_barrier, self.vllm_init_callback, ), name=f"rank-{self.local_rank}/vectorlm", @@ -172,6 +212,9 @@ def run(self) -> None: dtype=sampler_config.get("vllm_dtype", "auto"), ).create_engine_config() + # Block all N vectorlm threads until main process finished + # initializing vLLM Engine. + vllm_init_barrier = multiprocessing.Barrier(world_size + 1) vllm_result_handler = ResultHandler() # rank 0 worker runs in the __main__ process. @@ -189,14 +232,17 @@ def run(self) -> None: vectorlm_config, local_rank, world_size=world_size, + vllm_init_barrier=vllm_init_barrier, ) for local_rank in range(1, world_size) ] vllm_callback_wrapper = _VLLMCallbackWrapper( non_driver_workers, + vllm_result_handler, vllm_engine_config, vectorlm_config, world_size, + vllm_init_barrier, ) for worker in non_driver_workers: @@ -206,3 +252,5 @@ def run(self) -> None: assert vllm_callback_wrapper.llm is not None output = vllm_callback_wrapper.llm.generate("Vector Institute is") print(output) + + vllm_callback_wrapper.join_vectorlm_thread() diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 4b3a56a..255e90a 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -6,7 +6,6 @@ import os import threading import time -from collections import Counter from functools import partial from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple @@ -24,16 +23,16 @@ MultiProcGPUExecutor, _create_worker, ) -from vllm.utils import get_distributed_init_method, set_cuda_visible_devices +from vllm.utils import ( + Counter, + get_distributed_init_method, + set_cuda_visible_devices, +) from vllm.worker.worker import init_worker_distributed_environment if TYPE_CHECKING: - from threading import Barrier - from vllm.engine.arg_utils import EngineConfig - from vectorlm.utils import Config - from .abstract import AbstractSamplingEngine @@ -69,7 +68,7 @@ class ManagedMultiProcGPUExecutor(MultiProcGPUExecutor): workers: tuple[LocalWorkerVllm, ...] | None = None vectorlm_main_fn: Callable[[], None] | None = None - vectorlm_dist_init_barrier: Barrier | None = None + result_handler: ResultHandler | None = None def _init_executor(self) -> None: """Initialize executor without initializing workers. @@ -100,12 +99,12 @@ def _init_executor(self) -> None: world_size <= device_count() ), "please set tensor_parallel_size to less than max local gpu count" - result_handler = ResultHandler() + assert self.result_handler is not None self.worker_monitor = WorkerMonitor( list(self.workers), - result_handler, + self.result_handler, ) - result_handler.start() + self.result_handler.start() self.worker_monitor.start() distributed_init_method = _get_rdvz_url() @@ -123,11 +122,11 @@ def _init_executor(self) -> None: # start vectorlm logic in the same Python process # (albeit in a separate thread) - vectorlm_thread = threading.Thread( + self.vectorlm_thread = threading.Thread( target=self.vectorlm_main_fn, name="driver/vectorlm", ) - vectorlm_thread.start() + self.vectorlm_thread.start() self._init_driver_worker_and_model(0, 0, distributed_init_method) From 1f1f88e74c11e686295a551e652bb84802be2efc Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Mon, 6 May 2024 10:57:51 -0400 Subject: [PATCH 73/89] vllm integration [WIP]: Revised vectorlm-vllm concurrency handling. --- examples/llama_example.py | 48 ++++++++++++++++++++++-------------- examples/llama_example_mp.py | 19 +++++++++++--- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 65058cc..559c754 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -35,6 +35,8 @@ if TYPE_CHECKING: from threading import Barrier + from vllm import LLM + def parse_args() -> Namespace: """Parse command-line arguments. @@ -58,7 +60,7 @@ def main( local_rank: int | None = None, world_size: int | None = None, vllm_init_barrier: Barrier | None = None, - vllm_init_callback: Callable[[], None] | None = None, + get_vllm_engine: Callable[[], LLM] | None = None, ) -> None: """Define the main calling function.""" if vllm_init_barrier is not None: @@ -68,10 +70,15 @@ def main( print(f"rank {local_rank} vllm_init_barrier cleared") training_args = config.train_parameters + sampler_config = training_args.get("sampler") # set a seed set_seed(training_args.seed) + if dist.is_initialized(): + torch.cuda.set_device(local_rank) + torch.cuda.empty_cache() + # set CUDA related dependencies if (local_rank is None) or (world_size is None): local_rank = int(os.environ["LOCAL_RANK"]) @@ -83,15 +90,7 @@ def main( os.environ["RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(world_size) - logging.info( - "dist.init_process_group", - extra={"local_rank": local_rank, "world_size": world_size}, - ) - print(f"Rank: {rank}, World size: {world_size}") - if dist.is_initialized(): - torch.cuda.set_device(local_rank) - torch.cuda.empty_cache() # setup wandb if rank == 0: @@ -112,7 +111,6 @@ def main( "lora_peft_config", None, ) - sampler_config = config.train_parameters.get("sampler") is_peft_adapter_restored = False if lora_peft_config is not None: peft_adapter_path = None @@ -146,6 +144,10 @@ def main( is_lora_enabled=(lora_peft_config is not None), ) + if vllm_init_barrier is not None: + vllm_init_barrier.wait() + vllm_init_barrier.wait() + # load dataset dataset = Dataset( config=config.dataset, @@ -188,6 +190,16 @@ def main( # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + if sampler_config is not None: + if dist.get_rank() == 0: + assert get_vllm_engine is not None + vllm_llm = get_vllm_engine() + output = vllm_llm.generate("Vector Institute is") + print(output) + + print(f"rank {rank}: llm.generate barrier") + dist.barrier() + for epoch in range(checkpointed_epoch, training_args.epochs): train_dl_iterator = iter(dataset.train_dataloader) for index in tqdm( @@ -198,15 +210,15 @@ def main( batch = next(train_dl_iterator) trainer.step(batch, epoch) - # if ( - # (sampler_config is not None) - # and (index % training_args.sampler.sample_frequency == 0) - # and (dist.get_rank() == 0) - # ): - # output = vllm_llm.generate(training_args.sampler.prompts) - # print(output) + if ( + (sampler_config is not None) + and (index % training_args.sampler.sample_frequency == 0) + and (dist.get_rank() == 0) + ): + output = vllm_llm.generate(training_args.sampler.prompts[0]) + print(output) - # dist.barrier() + dist.barrier() if epoch == training_args.epochs - 1: hf_save_dir = os.path.join(training_args.output_dir, "final-model") diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index d7968e7..654be76 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -11,6 +11,9 @@ setup is to block the vectorlm thread with a multiprocessing synchronization feature (e.g., a Barrier shared across all processes) that the rank 0 process can remotely unblock. + +Edit: It seems that vllm.entrypoint.llm.LLM generate calls aren't +entirely blocking. """ from __future__ import annotations @@ -80,6 +83,7 @@ def __init__( 0, world_size, self.vllm_init_barrier, + self.get_engine, ) def initialize_engine(self) -> None: @@ -98,14 +102,11 @@ def initialize_engine(self) -> None: executor_class=ManagedMultiProcGPUExecutor, log_stats=False, ) - print("main: vllm_init_barrier waiting") - self.vllm_init_barrier.wait() - print("main: vllm_init_barrier cleared") self.llm = ManagedLLM(self.llm_engine) print(f"Instantiated ManagedLLM: {self.llm}") - def get_engine(self) -> None: + def get_engine(self) -> LLM: """Return LLM instance. Invoke this method only within the main (rank 0 driver) process. @@ -113,6 +114,10 @@ def get_engine(self) -> None: if self.llm is None: self.initialize_engine() + llm = self.llm + assert llm is not None + return llm + def join_vectorlm_thread(self) -> None: """Join the rank 0 (main process) vectorlm thread. @@ -250,6 +255,12 @@ def run(self) -> None: vllm_callback_wrapper.initialize_engine() assert vllm_callback_wrapper.llm is not None + print("main: vllm_init_barrier waiting") + vllm_init_barrier.wait() + print("main: vllm_init_barrier cleared") + + vllm_init_barrier.wait() + output = vllm_callback_wrapper.llm.generate("Vector Institute is") print(output) From 11a1ba598a72139b07e0bf1085274a1c819a8618 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Mon, 6 May 2024 14:28:19 -0400 Subject: [PATCH 74/89] vllm integration [WIP]: Implemented inference during training. --- examples/llama_example.py | 61 +++++++++++++++++++----------------- examples/llama_example_mp.py | 59 +++++++++++++++++++++++----------- vectorlm/sampling/utils.py | 22 +++++++++++++ 3 files changed, 95 insertions(+), 47 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 559c754..10fce0b 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -1,7 +1,6 @@ from __future__ import annotations import argparse -import logging import math import os import sys @@ -35,7 +34,9 @@ if TYPE_CHECKING: from threading import Barrier - from vllm import LLM + from vllm import LLM, RequestOutput + + from vectorlm.sampling.utils import SynchronizationBarriers def parse_args() -> Namespace: @@ -59,14 +60,24 @@ def main( config: Config, local_rank: int | None = None, world_size: int | None = None, - vllm_init_barrier: Barrier | None = None, - get_vllm_engine: Callable[[], LLM] | None = None, + barriers: SynchronizationBarriers | None = None, + get_vllm_llm: Callable[[], LLM] | None = None, ) -> None: - """Define the main calling function.""" - if vllm_init_barrier is not None: - # Wait until vllm engine is ready. + """Define the main calling function. + + Args: + ---- + config: vectorlm config, e.g., loaded from yaml + local_rank: int, where 0 is root process, one process per accelerator. + world_size: number of processes. + barriers: SynchronizationBarriers, required for all processes. + get_vllm_llm: required only for root process (rank 0). + + """ + if barriers is not None: + # Wait until vllm engine is fully initialized. print(f"rank {local_rank} vllm_init_barrier wait") - vllm_init_barrier.wait() + barriers.vllm_init.wait() print(f"rank {local_rank} vllm_init_barrier cleared") training_args = config.train_parameters @@ -144,10 +155,6 @@ def main( is_lora_enabled=(lora_peft_config is not None), ) - if vllm_init_barrier is not None: - vllm_init_barrier.wait() - vllm_init_barrier.wait() - # load dataset dataset = Dataset( config=config.dataset, @@ -190,16 +197,6 @@ def main( # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - if sampler_config is not None: - if dist.get_rank() == 0: - assert get_vllm_engine is not None - vllm_llm = get_vllm_engine() - output = vllm_llm.generate("Vector Institute is") - print(output) - - print(f"rank {rank}: llm.generate barrier") - dist.barrier() - for epoch in range(checkpointed_epoch, training_args.epochs): train_dl_iterator = iter(dataset.train_dataloader) for index in tqdm( @@ -210,13 +207,21 @@ def main( batch = next(train_dl_iterator) trainer.step(batch, epoch) - if ( - (sampler_config is not None) - and (index % training_args.sampler.sample_frequency == 0) - and (dist.get_rank() == 0) + if (sampler_config is not None) and ( + index % training_args.sampler.sample_frequency == 0 ): - output = vllm_llm.generate(training_args.sampler.prompts[0]) - print(output) + assert barriers is not None + barriers.before_generation.wait() + + if dist.get_rank() == 0: + assert get_vllm_llm is not None + # the line below should block until vllm finishes running. + output = get_vllm_llm().generate( + training_args.sampler.prompts, + ) + print(output) + + barriers.after_generation.wait() dist.barrier() diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index 654be76..5ee0056 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -24,12 +24,13 @@ import multiprocessing.context import multiprocessing.managers import threading -from functools import partial +from functools import partial, wraps from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: from vllm.worker.worker_base import WorkerBase +import vllm from llama_example import main from vllm.engine.arg_utils import EngineArgs, EngineConfig from vllm.engine.llm_engine import LLMEngine @@ -40,6 +41,7 @@ from vectorlm.sampling.utils import ( ManagedLLM, ManagedMultiProcGPUExecutor, + SynchronizationBarriers, _ensure_torch_dist_is_initialized, _get_rdvz_url, get_vllm_worker_factory, @@ -62,7 +64,7 @@ def __init__( engine_config: EngineConfig, vectorlm_config: Config, world_size: int, - vllm_init_barrier: threading.Barrier, + barriers: SynchronizationBarriers, ) -> None: """Instantiate class without initializing wrapped vLLM engine.""" self.llm_engine: LLMEngine | None = None @@ -70,7 +72,7 @@ def __init__( self.non_driver_workers = non_driver_workers self.vllm_result_handler = vllm_result_handler self.engine_config = engine_config - self.vllm_init_barrier = vllm_init_barrier + self.barriers = barriers # Might not be required since LLM.generate is blocking. # torch.dist.barrier might be sufficient for blocking @@ -82,8 +84,8 @@ def __init__( vectorlm_config, 0, world_size, - self.vllm_init_barrier, - self.get_engine, + self.barriers, + self.get_vllm_llm, ) def initialize_engine(self) -> None: @@ -106,7 +108,24 @@ def initialize_engine(self) -> None: self.llm = ManagedLLM(self.llm_engine) print(f"Instantiated ManagedLLM: {self.llm}") - def get_engine(self) -> LLM: + @wraps(LLM.generate) + def generate( + self, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> list[vllm.RequestOutput]: + """Invoke self.llm.generate. + + All args and kwargs are forwarded to llm.generate. + + Before invoking this method, make sure no other vectorlm threads + is using the GPU. This method blocks until vLLM finishes running + completely. + """ + assert self.llm is not None + return self.llm.generate(*args, **kwargs) + + def get_vllm_llm(self) -> LLM: """Return LLM instance. Invoke this method only within the main (rank 0 driver) process. @@ -151,7 +170,7 @@ def __init__( vectorlm_config: Config, local_rank: int, world_size: int, - vllm_init_barrier: threading.Barrier, + barriers: SynchronizationBarriers, ) -> None: """Instantiate LocalWorkerVllm wrapper. @@ -163,9 +182,7 @@ def __init__( self.vectorlm_config = vectorlm_config self.local_rank = local_rank self.world_size = world_size - self.vllm_init_barrier = vllm_init_barrier - - self.vllm_init_callback: Callable[[], None] | None = None + self.barriers = barriers super().__init__(result_handler, worker_factory) @@ -188,8 +205,7 @@ def run(self) -> None: self.vectorlm_config, self.local_rank, self.world_size, - self.vllm_init_barrier, - self.vllm_init_callback, + self.barriers, ), name=f"rank-{self.local_rank}/vectorlm", ) @@ -218,8 +234,15 @@ def run(self) -> None: ).create_engine_config() # Block all N vectorlm threads until main process finished - # initializing vLLM Engine. - vllm_init_barrier = multiprocessing.Barrier(world_size + 1) + # initializing vLLM Engine. Additionally, block vectorlm + # threads as long as vLLM tasks are running. + barriers = SynchronizationBarriers( + # (n+1) threads: __main__, and n vectorlm threads (including main). + multiprocessing.Barrier(world_size + 1), + # n vectorlm threads. + multiprocessing.Barrier(world_size), + multiprocessing.Barrier(world_size), + ) vllm_result_handler = ResultHandler() # rank 0 worker runs in the __main__ process. @@ -237,7 +260,7 @@ def run(self) -> None: vectorlm_config, local_rank, world_size=world_size, - vllm_init_barrier=vllm_init_barrier, + barriers=barriers, ) for local_rank in range(1, world_size) ] @@ -247,7 +270,7 @@ def run(self) -> None: vllm_engine_config, vectorlm_config, world_size, - vllm_init_barrier, + barriers, ) for worker in non_driver_workers: @@ -256,11 +279,9 @@ def run(self) -> None: vllm_callback_wrapper.initialize_engine() assert vllm_callback_wrapper.llm is not None print("main: vllm_init_barrier waiting") - vllm_init_barrier.wait() + barriers.vllm_init.wait() print("main: vllm_init_barrier cleared") - vllm_init_barrier.wait() - output = vllm_callback_wrapper.llm.generate("Vector Institute is") print(output) diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 255e90a..6f5b441 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -31,6 +31,8 @@ from vllm.worker.worker import init_worker_distributed_environment if TYPE_CHECKING: + from threading import Barrier + from vllm.engine.arg_utils import EngineConfig from .abstract import AbstractSamplingEngine @@ -49,6 +51,26 @@ class SampleOutput(NamedTuple): time_taken: float +class SynchronizationBarriers(NamedTuple): + """Barriers for limiting GPU access concurrency. + + Params: + vllm_init: Barrier to Ensures that vLLM engine is fully initialized + before running any vectorlm logic. + + before_generation: Ensure all processes have reached this statement, + or vectorlm in some processes might still be accessing the + accelerator when rank 0 invokes vLLM. + + after_generation: Detain all processes until after rank 0 is sure that + there are no outstanding vLLM jobs. + """ + + vllm_init: Barrier + before_generation: Barrier + after_generation: Barrier + + def _ensure_torch_dist_is_initialized() -> None: import torch.distributed From b697dc0909cadbbf3a47b2e861f49ff22beedfc5 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Tue, 7 May 2024 11:41:26 -0400 Subject: [PATCH 75/89] vllm integration [WIP]: Implemented lora hotswap. Still need to move barrier logic into _VLLMCallbackWrapper. --- examples/llama_example.py | 35 +++++++------ examples/llama_example_mp.py | 12 +++-- vectorlm/sampling/abstract.py | 19 +++++-- vectorlm/sampling/sampling_lora.py | 84 ++++++++++++++++-------------- vectorlm/utils/save_utils.py | 3 +- 5 files changed, 87 insertions(+), 66 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 10fce0b..f2f4b8e 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -12,8 +12,10 @@ from torch.optim import AdamW from tqdm import tqdm from transformers import set_seed +from vllm import SamplingParams from vectorlm.dataset import Dataset +from vectorlm.sampling import LoRASamplingEngine from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup @@ -32,9 +34,7 @@ ) if TYPE_CHECKING: - from threading import Barrier - - from vllm import LLM, RequestOutput + from vllm import LLM from vectorlm.sampling.utils import SynchronizationBarriers @@ -104,7 +104,7 @@ def main( print(f"Rank: {rank}, World size: {world_size}") # setup wandb - if rank == 0: + if rank == 0 and config.enable_wandb_logging: wandb_setup(config, **config.wandb_config) # load model and tokenizer @@ -197,6 +197,16 @@ def main( # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) + if sampler_config is not None: + vllm_llm = get_vllm_llm() if get_vllm_llm is not None else None + print("Initializing sampling_engine") + sampling_engine = LoRASamplingEngine( + trainer, + vllm_llm, # required only for rank 0 + SamplingParams(seed=0, temperature=0), + barriers, + ) + for epoch in range(checkpointed_epoch, training_args.epochs): train_dl_iterator = iter(dataset.train_dataloader) for index in tqdm( @@ -210,18 +220,9 @@ def main( if (sampler_config is not None) and ( index % training_args.sampler.sample_frequency == 0 ): - assert barriers is not None - barriers.before_generation.wait() - - if dist.get_rank() == 0: - assert get_vllm_llm is not None - # the line below should block until vllm finishes running. - output = get_vllm_llm().generate( - training_args.sampler.prompts, - ) - print(output) - - barriers.after_generation.wait() + sampling_engine.update(trainer) + output = sampling_engine.generate(sampler_config.prompts) + print(output[0].prompt + output[0].outputs[0].text) dist.barrier() @@ -242,6 +243,8 @@ def main( dataset.reset_dataloaders() + sys.exit(0) + if __name__ == "__main__": args = parse_args() diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index 5ee0056..c5521f6 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -19,7 +19,6 @@ from __future__ import annotations import argparse -import logging import multiprocessing import multiprocessing.context import multiprocessing.managers @@ -48,8 +47,6 @@ ) from vectorlm.utils.data_utils import Config -logging.basicConfig(level=logging.DEBUG) - class _VLLMCallbackWrapper: """Provide vLLM Engine access to multiprocess.Process workers. @@ -121,6 +118,12 @@ def generate( Before invoking this method, make sure no other vectorlm threads is using the GPU. This method blocks until vLLM finishes running completely. + + Note that it might be more elegant to use generate instead of + directly invoking LLM.generate, so that this implementation can + handle the broadcasting and synchronization safely. However, + doing so would prevent some IDE from inferring the argument types + correctly. """ assert self.llm is not None return self.llm.generate(*args, **kwargs) @@ -231,6 +234,7 @@ def run(self) -> None: ), tensor_parallel_size=world_size, dtype=sampler_config.get("vllm_dtype", "auto"), + enable_lora=True, ).create_engine_config() # Block all N vectorlm threads until main process finished @@ -283,6 +287,6 @@ def run(self) -> None: print("main: vllm_init_barrier cleared") output = vllm_callback_wrapper.llm.generate("Vector Institute is") - print(output) + print(output[0].prompt + output[0].outputs[0].text) vllm_callback_wrapper.join_vectorlm_thread() diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index 51ed6f8..961f350 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -3,11 +3,15 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import vllm from vectorlm.trainer import Trainer +if TYPE_CHECKING: + from .utils import SynchronizationBarriers + class AbstractSamplingEngine(ABC): """Interface for the sampling engine.""" @@ -15,17 +19,24 @@ class AbstractSamplingEngine(ABC): def __init__( self, trainer: Trainer, + vllm_llm: vllm.LLM | None = None, sampling_params: vllm.SamplingParams | None = None, + synchronization_barriers: SynchronizationBarriers | None = None, ) -> None: """Initialize sampling engine. Params: trainer: Trainer instance. + vllm_llm: Instance of vllm.LLM, required only for rank 0. sampling_params: Optionally, specify default sampling params. + synchronization_barriers: Optionally, supply barriers to + prevent workers from accessing GPU while vLLM is running. """ self.trainer = trainer + self.vllm_llm = vllm_llm self.sampling_params = sampling_params + self.synchronization_barriers = synchronization_barriers def update(self, trainer: Trainer | None = None) -> None: """Inform the sampling engine that the model in trainer is updated. @@ -41,9 +52,11 @@ def generate( self, prompts: list[str], sampling_params: vllm.SamplingParams | None = None, - ) -> list[list[vllm.CompletionOutput]]: + ) -> list[vllm.RequestOutput]: """Generate continuation for the given prompts synchronously. + Invoke at all ranks. Output will be broadcasted to all ranks. + Params: ------ prompts: List of input prompts. @@ -52,8 +65,6 @@ def generate( Returns ------- - Output from vllm: list[list[vllm.CompletionOutput]] - outer layer: one for each prompt. - inner layer: one for each output option for the prompt. + Output from vllm: list[vllm.RequestOutput], one for each prompt. """ diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py index d4cc00d..45d7b99 100644 --- a/vectorlm/sampling/sampling_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -11,6 +11,7 @@ from vectorlm.utils.save_utils import save_peft_adapter from .abstract import AbstractSamplingEngine +from .utils import SynchronizationBarriers class LoRASamplingEngine(AbstractSamplingEngine): @@ -19,26 +20,22 @@ class LoRASamplingEngine(AbstractSamplingEngine): def __init__( self, trainer: Trainer, + vllm_llm: vllm.LLM | None = None, sampling_params: vllm.SamplingParams | None = None, - base_model_name: str | None = None, - tensor_parallel_size: int = 1, - gpu_memory_utilization: float = 0.3, + synchronization_barriers: SynchronizationBarriers | None = None, adapter_temp_folder: str | None = None, ) -> None: """Initialize sampling engine. Params: trainer: Trainer instance. + vllm_llm: Instance of vllm.LLM, required only for rank 0. sampling_params: Optionally, specify default sampling params. - base_model_name: Path or HuggingFace repo name of base model. - tensor_parallel_size: Forwarded to vllm.LLM. - gpu_memory_utilization: Forwarded to vllm.LLM. adapter_temp_folder: Temporary path where temporary adapter weights are saved. If not specified, f`/dev/shm/{job_id}` """ - if dist.get_rank() != 0: - return - + assert synchronization_barriers is not None + self.barriers = synchronization_barriers self.sampling_params = sampling_params if adapter_temp_folder is not None: @@ -54,18 +51,12 @@ def __init__( slurm_job_id_or_placeholder, ) - assert ( - base_model_name is not None - ), "base_model_name is required when instantiating LoRASamplingEngine." - - self.vllm_llm = vllm.LLM( - base_model_name, - tensor_parallel_size=tensor_parallel_size, - gpu_memory_utilization=gpu_memory_utilization, - enable_lora=True, - ) + if dist.get_rank() == 0: + assert vllm_llm is not None + self.vllm_llm = vllm_llm - # Trigger FSDP initialization before + # Trigger FSDP initialization before retrieving weights. + # Otherwise FSDP is_root flag might be set incorrectly. _wrapped_model = trainer.model assert _wrapped_model is not None _wrapped_model(input_ids=torch.zeros((1, 1), dtype=torch.int)) @@ -79,9 +70,6 @@ def update(self, trainer: Trainer | None = None) -> None: Params: trainer: Optionally, replace self.trainer with the provided value. """ - if dist.get_rank() != 0: - return - if trainer is not None: self.trainer = trainer @@ -103,8 +91,10 @@ def generate( self, prompts: list[str], sampling_params: vllm.SamplingParams | None = None, - ) -> list[list[vllm.CompletionOutput]]: - """Generate continuation for the given prompts. Invoke only on rank 0. + ) -> list[vllm.RequestOutput]: + """Generate continuation for the given prompts. Invoke at all ranks. + + Output will be broadcasted to all ranks. Params: ------ @@ -114,20 +104,34 @@ def generate( Returns ------- - Output from vllm: list[list[vllm.CompletionOutput]] - outer layer: one for each prompt. - inner layer: one for each output option for the prompt. + Output from vllm: list[vllm.RequestOutput], one for each prompt. """ - if dist.get_rank() != 0: - msg = "LoRA sampling engine is supported only on rank 0." - raise RuntimeError(msg) - - assert self.vllm_train_step is not None - output_list = self.vllm_llm.generate( - prompts, - sampling_params, - lora_request=self.lora_request, - use_tqdm=True, - ) - return [output.outputs for output in output_list] + # placeholder for output value, + # populate on rank 0 and then broadcast. + return_value_local: list[vllm.RequestOutput] | list[None] + self.barriers.before_generation.wait() + + if dist.get_rank() == 0: + assert self.vllm_train_step is not None + return_value_local = self.vllm_llm.generate( + prompts, + sampling_params, + lora_request=self.lora_request, + use_tqdm=True, + ) + assert len(return_value_local) == len(prompts) + + else: + # torch requires placeholder output lists of same length as src. + return_value_local = [None] * len(prompts) + + self.barriers.after_generation.wait() + + dist.broadcast_object_list(return_value_local) + return_value: list[vllm.RequestOutput] = [] + for broadcasted_item in return_value_local: + assert broadcasted_item is not None + return_value.append(broadcasted_item) + + return return_value diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index 44b580e..594e0b2 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -201,8 +201,7 @@ def save_peft_adapter( StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): - if dist.get_rank() == 0: - model.save_pretrained(output_path) + model.save_pretrained(output_path) def save_optimizer( From 112ea3c21546bec3695b41f2eba6f25bec05092f Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 9 May 2024 15:30:50 -0400 Subject: [PATCH 76/89] vllm integration [WIP]: Moved sampler-related logic into Trainer. --- configs/config_gemma.yaml | 2 +- examples/llama_example.py | 21 +++------ examples/llama_example_mp.py | 25 +--------- vectorlm/sampling/__init__.py | 1 + vectorlm/sampling/abstract.py | 4 +- vectorlm/sampling/sampling_lora.py | 51 ++++++++++----------- vectorlm/sampling/utils.py | 73 ++++++++++++++++++++++++++---- vectorlm/trainer.py | 23 +++++++++- vectorlm/utils/save_utils.py | 5 +- 9 files changed, 124 insertions(+), 81 deletions(-) diff --git a/configs/config_gemma.yaml b/configs/config_gemma.yaml index 6668cbd..2e9d16e 100644 --- a/configs/config_gemma.yaml +++ b/configs/config_gemma.yaml @@ -36,7 +36,7 @@ train_parameters: # Optimizer optimizer: - lr: 5.0e-5 + lr: 1.0e-4 weight_decay: 0.1 betas: [0.9, 0.95] eps: 1.0e-5 diff --git a/examples/llama_example.py b/examples/llama_example.py index f2f4b8e..09d888b 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -193,37 +193,30 @@ def main( is_peft_adapter_restored, ) - # Checkpoint check. Always call before training. - # If no checkpoint, it returns 0. - checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - if sampler_config is not None: + # vllm_llm is required only on rank 0. vllm_llm = get_vllm_llm() if get_vllm_llm is not None else None - print("Initializing sampling_engine") sampling_engine = LoRASamplingEngine( trainer, vllm_llm, # required only for rank 0 SamplingParams(seed=0, temperature=0), barriers, ) + trainer.sampling_engine = sampling_engine + + # Checkpoint check. Always call before training. + # If no checkpoint, it returns 0. + checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) for epoch in range(checkpointed_epoch, training_args.epochs): train_dl_iterator = iter(dataset.train_dataloader) - for index in tqdm( + for _ in tqdm( range(len(dataset.train_dataloader)), disable=rank != 0, file=sys.__stdout__, ): batch = next(train_dl_iterator) trainer.step(batch, epoch) - - if (sampler_config is not None) and ( - index % training_args.sampler.sample_frequency == 0 - ): - sampling_engine.update(trainer) - output = sampling_engine.generate(sampler_config.prompts) - print(output[0].prompt + output[0].outputs[0].text) - dist.barrier() if epoch == training_args.epochs - 1: diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index c5521f6..9a74efb 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -23,7 +23,7 @@ import multiprocessing.context import multiprocessing.managers import threading -from functools import partial, wraps +from functools import partial from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: @@ -105,29 +105,6 @@ def initialize_engine(self) -> None: self.llm = ManagedLLM(self.llm_engine) print(f"Instantiated ManagedLLM: {self.llm}") - @wraps(LLM.generate) - def generate( - self, - *args, # noqa: ANN002 - **kwargs, # noqa: ANN003 - ) -> list[vllm.RequestOutput]: - """Invoke self.llm.generate. - - All args and kwargs are forwarded to llm.generate. - - Before invoking this method, make sure no other vectorlm threads - is using the GPU. This method blocks until vLLM finishes running - completely. - - Note that it might be more elegant to use generate instead of - directly invoking LLM.generate, so that this implementation can - handle the broadcasting and synchronization safely. However, - doing so would prevent some IDE from inferring the argument types - correctly. - """ - assert self.llm is not None - return self.llm.generate(*args, **kwargs) - def get_vllm_llm(self) -> LLM: """Return LLM instance. diff --git a/vectorlm/sampling/__init__.py b/vectorlm/sampling/__init__.py index ebdc603..7810e25 100644 --- a/vectorlm/sampling/__init__.py +++ b/vectorlm/sampling/__init__.py @@ -1,2 +1,3 @@ from .abstract import AbstractSamplingEngine from .sampling_lora import LoRASamplingEngine +from .utils import handle_sample, multiprocess_wrap diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index 961f350..c9f105a 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -7,9 +7,9 @@ import vllm -from vectorlm.trainer import Trainer - if TYPE_CHECKING: + from vectorlm.trainer import Trainer + from .utils import SynchronizationBarriers diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py index 45d7b99..368d11a 100644 --- a/vectorlm/sampling/sampling_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -1,17 +1,20 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING, Callable import torch import torch.distributed as dist import vllm from vllm.lora.request import LoRARequest -from vectorlm.trainer import Trainer from vectorlm.utils.save_utils import save_peft_adapter from .abstract import AbstractSamplingEngine -from .utils import SynchronizationBarriers +from .utils import SynchronizationBarriers, multiprocess_wrap + +if TYPE_CHECKING: + from vectorlm.trainer import Trainer class LoRASamplingEngine(AbstractSamplingEngine): @@ -54,6 +57,14 @@ def __init__( if dist.get_rank() == 0: assert vllm_llm is not None self.vllm_llm = vllm_llm + generate_fn_raw = vllm_llm.generate + else: + # placeholder, as the wrapped_fn won't be invoked outside rank-0. + generate_fn_raw: Callable[..., list[vllm.RequestOutput]] = ( + lambda: None + ) # type: ignore [] + + self.generate_fn = multiprocess_wrap(generate_fn_raw, self.barriers) # Trigger FSDP initialization before retrieving weights. # Otherwise FSDP is_root flag might be set incorrectly. @@ -76,6 +87,7 @@ def update(self, trainer: Trainer | None = None) -> None: wrapped_model = self.trainer.model assert wrapped_model is not None + self.barriers.before_generation.wait() if self.vllm_train_step != self.trainer.tr_step: save_peft_adapter(wrapped_model, self.adapter_temp_folder) assert self.trainer.tr_step is not None @@ -87,6 +99,8 @@ def update(self, trainer: Trainer | None = None) -> None: self.adapter_temp_folder, ) + self.barriers.after_generation.wait() + def generate( self, prompts: list[str], @@ -107,31 +121,12 @@ def generate( Output from vllm: list[vllm.RequestOutput], one for each prompt. """ - # placeholder for output value, - # populate on rank 0 and then broadcast. - return_value_local: list[vllm.RequestOutput] | list[None] - self.barriers.before_generation.wait() - - if dist.get_rank() == 0: - assert self.vllm_train_step is not None - return_value_local = self.vllm_llm.generate( - prompts, - sampling_params, - lora_request=self.lora_request, - use_tqdm=True, - ) - assert len(return_value_local) == len(prompts) - - else: - # torch requires placeholder output lists of same length as src. - return_value_local = [None] * len(prompts) - - self.barriers.after_generation.wait() - - dist.broadcast_object_list(return_value_local) - return_value: list[vllm.RequestOutput] = [] - for broadcasted_item in return_value_local: - assert broadcasted_item is not None - return_value.append(broadcasted_item) + return_value = self.generate_fn( + prompts, + sampling_params, + lora_request=self.lora_request, + use_tqdm=True, + ) + assert len(return_value) == len(prompts) return return_value diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 6f5b441..5c18576 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -7,13 +7,9 @@ import threading import time from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple +from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar -if TYPE_CHECKING: - from vllm import LLMEngine - from vllm.worker.worker_base import WorkerBase - -from vllm import LLM, SamplingParams +from vllm import LLM from vllm.engine.local_worker_utils import ( LocalWorkerVllm, ResultHandler, @@ -30,12 +26,14 @@ ) from vllm.worker.worker import init_worker_distributed_environment +from .abstract import AbstractSamplingEngine + if TYPE_CHECKING: from threading import Barrier + from vllm import LLMEngine, SamplingParams from vllm.engine.arg_utils import EngineConfig - -from .abstract import AbstractSamplingEngine + from vllm.worker.worker_base import WorkerBase class SampleOutput(NamedTuple): @@ -176,6 +174,7 @@ def handle_sample( sampling_engine: an instantiation of sampling engine. prompts: a list (iterable) of prompts. output_path: if provided, append output json lines to this file. + Recommended: specify output_path only on rank 0. sampling_params: forwarded to sampling engine. extra_data: prepended to each line of output (e.g., current epoch.) @@ -192,11 +191,11 @@ def handle_sample( # Parse sample engine output and keep only the output strings. sample_outputs: list[SampleOutput] = [] - for prompt, options in zip(prompts, generation_output): + for prompt, request_output in zip(prompts, generation_output): sample_outputs.append( SampleOutput( prompt, - [option.text for option in options], + [option.text for option in request_output.outputs], time_taken, ), ) @@ -234,3 +233,57 @@ def get_vllm_worker_factory( vision_language_config=engine_config.vision_language_config, tensorizer_config=engine_config.tensorizer_config, ) + + +Fn = TypeVar("Fn", bound=Callable[..., Any]) + + +def multiprocess_wrap(fn: Fn, barriers: SynchronizationBarriers) -> Fn: + """Apply barrier to function and broadcast output. + + This wrapper function tries to preserve the type signature + of the wrapped function for the IDE. Tested for Pylance. + + While fn would be invoked only on rank 0, the wrapped function + should be invoked in the vectorlm thread at all ranks, so that + the barriers would block these threads from accessing GPU while + the fn is running. + + Each rank would receive the same value as output. + + Params: + ------- + fn: Function to wrap. Output needs to be compatible with pickle. + barriers: SynchronizationBarriers, only the before_generation and + after_generation barriers are required.. + + Returns + ------- + same output as Fn, but broadcasted to all ranks + (i.e., same value at all ranks) + + """ + + def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 + barriers.after_generation.wait() + + import torch.distributed + + rank = torch.distributed.get_rank() + + # placeholder for output value, + # populate on rank 0 and then broadcast. + # torch requires placeholder element in object list. + output = [None] + if rank == 0: + output = [fn(*args, **kwargs)] + + # fn might access torch.dist, which might conflict with + # broadcast_object_list. Hence, keep all ranks witing until fn returns. + # on rank 0. + barriers.before_generation.wait() + + torch.distributed.broadcast_object_list(output) + return output[0] + + return _wrapped_fn # type: ignore[] diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 30165eb..4b7a01b 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -2,7 +2,7 @@ import math import os -from typing import Any +from typing import TYPE_CHECKING, Any import peft import torch @@ -13,6 +13,7 @@ import wandb from vectorlm.dataset import Dataset +from vectorlm.sampling import handle_sample from vectorlm.utils.data_utils import Config from vectorlm.utils.save_utils import ( checkpoint_exists, @@ -28,6 +29,9 @@ save_scheduler, ) +if TYPE_CHECKING: + from vectorlm.sampling import AbstractSamplingEngine + class Trainer: """Main trainer class. @@ -90,6 +94,7 @@ def __init__( self.max_steps = None self.saving_steps = None self._post_process(original_dataset_length) + self.sampling_engine: AbstractSamplingEngine | None = None if hasattr(self.config, "lora_peft_config"): self.peft_method = peft.utils.peft_types.PeftType.LORA @@ -268,6 +273,22 @@ def step( test_loss = None if self.tr_step % self.logging_steps == 0: test_loss = self.eval_step(epoch) + + if (self.sampling_engine is not None) and ( + self.tr_step % self.config.sampler.sample_frequency == 0 + ): + self.sampling_engine.update(self) + handle_sample( + self.sampling_engine, + self.config.sampler.prompts, + output_path=( + self.config.sampler.output_jsonl_path + if dist.get_rank() == 0 + else None + ), + extra_data={"train_step": self.tr_step}, + ) + self.tr_step += 1 return train_loss, test_loss diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index 594e0b2..53db11f 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -201,7 +201,10 @@ def save_peft_adapter( StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): - model.save_pretrained(output_path) + model.save_pretrained( + output_path, + is_main_process=(dist.get_rank() == 0), + ) def save_optimizer( From e707987bf6f3cd86a745a1089c0e959a0be8dc28 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 9 May 2024 20:34:42 -0400 Subject: [PATCH 77/89] vllm integration: Added documentation on sampling engine. --- docs/sampling.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 docs/sampling.md diff --git a/docs/sampling.md b/docs/sampling.md new file mode 100644 index 0000000..dff995f --- /dev/null +++ b/docs/sampling.md @@ -0,0 +1,25 @@ +# Efficient Sampling during training + +Some training objectives, noteably PPO, require "sampling" from the language model many times during training. One possible approach is to invoke the model.generate that HuggingFace provides. At the same time, there have been a number of efficient inference approaches, including vLLM and others, that achieve over 10x the sampling throughput in terms of tokens generated per second when using a large sampling batch size. If model.generate is taking up too much of the training time, it might be worthwhile looking into these third-party solutions for speeding up the sampling process. + +One main challenge of running these third-party solutions, however, is that most of them assume that the weights of the language model are fixed and do not provide a straightforward way of updating these weights without restarting the sampling engine, which sometimes take minutes. However, the performance of PPO and similar techniques heavily rely on the ability to replace the weights efficiently, or else the training would no longer be on-policy and convergence would take substantially more training steps. + +Additionally, it is not straightforward to ensure a consistently high GPU utilization throughout both training and sampling. Existing approaches, such as [OpenRLHF Ray PPO](https://github.com/OpenLLMAI/OpenRLHF/blob/adf26867e44765a3963b4e8d249cf58a5162209c/examples/train_ppo_ray.py), uses some GPUs exclusively for sampling so that none of the GPUs would be idle while others are busy running training/inference. This repository enables you to make the most out of your GPUs by fitting vLLM and your training loop into the same set of devices. This way, none of the GPUs would sit idle- if a GPU is not running training, it would be busy sampling using vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach. + +## Example- Supervised fine-tuning + +We provide a basic example that samples from the language model while fine-tuning using a basic causal language modelling objective. To run the example, uncomment the "sampler" section in your configuration yaml, choose a port for `nccl` coordination, and run the following command (not using torchrun): + +``` +export MASTER_ADDR=127.0.0.1 +export MASTER_PORT=19132 +python3 examples/llama_example_mp.py \ +--yaml_path configs/config.yaml \ +--world_size 2 +``` + +## Bring your own training loop + +While the reference implementation is only for supervised fine-tuning, we provide abstractions that make it easier for you to implement your own training loop- be it PPO RLHF, TWIST, or something else. The goal is to abstract away all the synchronization logic, so that a training loop you've built on one GPU could scale to multiple GPUs on the same server with minimal modifications. + +To get started, refer to examples/llama_example.py and vectorlm/trainer.py. Usually, the vLLM Engine is accessible only from the rank 0, making synchronization challenging. When invoked through llama_example_mp, the `SamplingEngine` interface in VectorLM enables your training loop to access vLLM.LLM.generate from all ranks, returning the same result across all ranks. Note that because the synchronization barriers require all ranks to reach the synchronization point, you need to invoke `generate` from all ranks. From 61c39ade355e9fa7743d1dbc481d940aac4da876 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 9 May 2024 20:45:02 -0400 Subject: [PATCH 78/89] vllm integration: Added documentation on sampling engine. --- docs/sampling.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/sampling.md b/docs/sampling.md index dff995f..90cf4c5 100644 --- a/docs/sampling.md +++ b/docs/sampling.md @@ -1,10 +1,11 @@ # Efficient Sampling during training -Some training objectives, noteably PPO, require "sampling" from the language model many times during training. One possible approach is to invoke the model.generate that HuggingFace provides. At the same time, there have been a number of efficient inference approaches, including vLLM and others, that achieve over 10x the sampling throughput in terms of tokens generated per second when using a large sampling batch size. If model.generate is taking up too much of the training time, it might be worthwhile looking into these third-party solutions for speeding up the sampling process. +Some training objectives, noteably PPO, require "sampling" from the language model many times during training. The most straightforward approach might be to invoke model.generate on the model from within the training loop. Nevertheless, there have been a number of alternative inference approaches, including vLLM and others, promising over 10x the sampling throughput in terms of tokens generated per second when using a large sampling batch size. If model.generate is taking up too much of the training time, it might be worthwhile looking into these third-party solutions for speeding up the sampling process. -One main challenge of running these third-party solutions, however, is that most of them assume that the weights of the language model are fixed and do not provide a straightforward way of updating these weights without restarting the sampling engine, which sometimes take minutes. However, the performance of PPO and similar techniques heavily rely on the ability to replace the weights efficiently, or else the training would no longer be on-policy and convergence would take substantially more training steps. +One main challenge of running these third-party solutions, however, is that most of them assume that the weights of the language model are fixed, such that there isn't a straightforward way of updating these weights. Usually, updating the weights requires restarting the sampling engine, which sometimes take minutes. At the same time, the performance of PPO and similar techniques heavily rely on the ability to replace the weights efficiently, or else the training would no longer be on-policy and convergence would take substantially more training steps. To resolve this issue, we implemented techniques to "hot-swap" the model parameters that are used in the sampling process. -Additionally, it is not straightforward to ensure a consistently high GPU utilization throughout both training and sampling. Existing approaches, such as [OpenRLHF Ray PPO](https://github.com/OpenLLMAI/OpenRLHF/blob/adf26867e44765a3963b4e8d249cf58a5162209c/examples/train_ppo_ray.py), uses some GPUs exclusively for sampling so that none of the GPUs would be idle while others are busy running training/inference. This repository enables you to make the most out of your GPUs by fitting vLLM and your training loop into the same set of devices. This way, none of the GPUs would sit idle- if a GPU is not running training, it would be busy sampling using vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach. +Additionally, it is not straightforward to ensure a consistently high GPU utilization when combining sampling with training. +This repository enables you to make the most out of all your GPUs by fitting vLLM and your training loop into the same set of devices. This way, none of the GPUs would sit idle- if a GPU is not running training, it would be busy sampling using vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach. ## Example- Supervised fine-tuning From 609c023af14565f117bb80d6abea25d62cf5496b Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 23 May 2024 13:18:27 -0400 Subject: [PATCH 79/89] [WIP] vllm hotswapping: Implement minimum-viable wrapper for vllm/main. Cleanup is required. --- examples/llama_example.py | 14 +- examples/llama_example_mp.py | 137 ++----------------- vectorlm/sampling/utils.py | 178 +++++++++++++------------ vectorlm/sampling/vllm_worker_utils.py | 63 +++++++++ 4 files changed, 175 insertions(+), 217 deletions(-) create mode 100644 vectorlm/sampling/vllm_worker_utils.py diff --git a/examples/llama_example.py b/examples/llama_example.py index cb2c87a..210e3f1 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -58,20 +58,20 @@ def parse_args() -> Namespace: def main( config: Config, - local_rank: int | None = None, world_size: int | None = None, - barriers: SynchronizationBarriers | None = None, get_vllm_llm: Callable[[], LLM] | None = None, + barriers: SynchronizationBarriers | None = None, + local_rank: int | None = None, ) -> None: """Define the main calling function. Args: ---- config: vectorlm config, e.g., loaded from yaml - local_rank: int, where 0 is root process, one process per accelerator. world_size: number of processes. - barriers: SynchronizationBarriers, required for all processes. get_vllm_llm: required only for root process (rank 0). + barriers: SynchronizationBarriers, required for all processes. + local_rank: int, where 0 is root process, one process per accelerator. """ if barriers is not None: @@ -195,7 +195,11 @@ def main( if sampler_config is not None: # vllm_llm is required only on rank 0. - vllm_llm = get_vllm_llm() if get_vllm_llm is not None else None + vllm_llm = ( + get_vllm_llm() + if (get_vllm_llm is not None) and (rank == 0) + else None + ) sampling_engine = LoRASamplingEngine( trainer, vllm_llm, # required only for rank 0 diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index 3e6cd08..c39a5fd 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -19,30 +19,19 @@ from __future__ import annotations import argparse -import multiprocessing -import multiprocessing.context -import multiprocessing.managers -import threading from functools import partial -from typing import TYPE_CHECKING, Callable - -if TYPE_CHECKING: - from vllm.worker.worker_base import WorkerBase +from typing import Callable from llama_example import main from vllm.engine.arg_utils import EngineArgs, EngineConfig from vllm.engine.llm_engine import LLMEngine -from vllm.engine.local_worker_utils import LocalWorkerVllm, ResultHandler from vllm.entrypoints.llm import LLM -from vllm.worker.worker import init_worker_distributed_environment +from vllm.executor.multiproc_worker_utils import ResultHandler, mp from vectorlm.sampling.utils import ( ManagedLLM, ManagedMultiProcGPUExecutor, SynchronizationBarriers, - _ensure_torch_dist_is_initialized, - _get_rdvz_url, - get_vllm_worker_factory, ) from vectorlm.utils.data_utils import Config @@ -55,8 +44,6 @@ class _VLLMCallbackWrapper: def __init__( self, - non_driver_workers: list[VectorLMWorker], - vllm_result_handler: ResultHandler, engine_config: EngineConfig, vectorlm_config: Config, world_size: int, @@ -65,23 +52,16 @@ def __init__( """Instantiate class without initializing wrapped vLLM engine.""" self.llm_engine: LLMEngine | None = None self.llm: LLM | None = None - self.non_driver_workers = non_driver_workers - self.vllm_result_handler = vllm_result_handler self.engine_config = engine_config self.barriers = barriers - # Might not be required since LLM.generate is blocking. - # torch.dist.barrier might be sufficient for blocking - # other worker processes. - self.gpu_access_lock = threading.Lock() - - self.vectorlm_main_fn = partial( + # Only missing args is local_rank. + self.vectorlm_fn: Callable[[int], None] = partial( main, vectorlm_config, - 0, world_size, - self.barriers, self.get_vllm_llm, + self.barriers, ) def initialize_engine(self) -> None: @@ -89,11 +69,7 @@ def initialize_engine(self) -> None: Invoke this method only after vLLM workers are all ready. """ - ManagedMultiProcGPUExecutor.workers = tuple( - self.non_driver_workers, - ) - ManagedMultiProcGPUExecutor.vectorlm_main_fn = self.vectorlm_main_fn - ManagedMultiProcGPUExecutor.result_handler = self.vllm_result_handler + ManagedMultiProcGPUExecutor.vectorlm_fn = self.vectorlm_fn self.llm_engine = LLMEngine( **self.engine_config.to_dict(), @@ -125,73 +101,7 @@ def join_vectorlm_thread(self) -> None: assert self.llm_engine is not None model_executor = self.llm_engine.model_executor assert isinstance(model_executor, ManagedMultiProcGPUExecutor) - - model_executor.vectorlm_thread.join() - - -class VectorLMWorker(LocalWorkerVllm): - """Worker for running VectorLM logic alongside vLLM worker. - - Important: do not use this instance for the rank 0 (root) process. - - Note that nccl requires that only one process may have access - to each GPU. Each LocalWorkerVllm is a multiprocessing.Process. - Vectorlm logic would be launched as a thread within each of these - proceses. - - Spawn no more than one such instance for each GPU. - """ - - def __init__( - self, - result_handler: ResultHandler, - worker_factory: Callable[[], WorkerBase], - vllm_engine_config: EngineConfig, - vectorlm_config: Config, - local_rank: int, - world_size: int, - barriers: SynchronizationBarriers, - ) -> None: - """Instantiate LocalWorkerVllm wrapper. - - vectorlm_dist_init_barrier ensures that torch.dist is initialized in - the vectorlm thread and not the main thread (vllm) of the process. - """ - self.vllm_engine_config = vllm_engine_config - self.gpu_access_lock = threading.Lock() - self.vectorlm_config = vectorlm_config - self.local_rank = local_rank - self.world_size = world_size - self.barriers = barriers - - super().__init__(result_handler, worker_factory) - - def run(self) -> None: - """Launch vectorlm logic in a separate thread.""" - print(f"rank {self.local_rank}: init_worker_dist started") - init_worker_distributed_environment( - self.vllm_engine_config.parallel_config, - self.local_rank, - _get_rdvz_url(), - self.local_rank, - ) - print(f"rank {self.local_rank}: init_worker_dist completed") - - _ensure_torch_dist_is_initialized() - - self.vectorlm_thread = threading.Thread( - target=main, - args=( - self.vectorlm_config, - self.local_rank, - self.world_size, - self.barriers, - ), - name=f"rank-{self.local_rank}/vectorlm", - ) - self.vectorlm_thread.start() - - super().run() + model_executor.rank_0_vectorlm_thread.join() if __name__ == "__main__": @@ -219,51 +129,30 @@ def run(self) -> None: # threads as long as vLLM tasks are running. barriers = SynchronizationBarriers( # (n+1) threads: __main__, and n vectorlm threads (including main). - multiprocessing.Barrier(world_size + 1), + mp.Barrier(world_size + 1), # n vectorlm threads. - multiprocessing.Barrier(world_size), - multiprocessing.Barrier(world_size), + mp.Barrier(world_size), + mp.Barrier(world_size), ) vllm_result_handler = ResultHandler() # rank 0 worker runs in the __main__ process. # all other ranks use one process each. # vectorlm logic in each ranks (including rank 0) is in a separate thread. - non_driver_workers: list[VectorLMWorker] = [ - VectorLMWorker( - vllm_result_handler, - get_vllm_worker_factory( - vllm_engine_config, - _get_rdvz_url(), - local_rank, - ), - vllm_engine_config, - vectorlm_config, - local_rank, - world_size=world_size, - barriers=barriers, - ) - for local_rank in range(1, world_size) - ] vllm_callback_wrapper = _VLLMCallbackWrapper( - non_driver_workers, - vllm_result_handler, vllm_engine_config, vectorlm_config, world_size, barriers, ) - for worker in non_driver_workers: - worker.start() - vllm_callback_wrapper.initialize_engine() assert vllm_callback_wrapper.llm is not None + output = vllm_callback_wrapper.llm.generate("Vector Institute is") + print(output[0].prompt + output[0].outputs[0].text) + print("main: vllm_init_barrier waiting") barriers.vllm_init.wait() print("main: vllm_init_barrier cleared") - output = vllm_callback_wrapper.llm.generate("Vector Institute is") - print(output[0].prompt + output[0].outputs[0].text) - vllm_callback_wrapper.join_vectorlm_thread() diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 5c18576..7306979 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -10,30 +10,27 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar from vllm import LLM -from vllm.engine.local_worker_utils import ( - LocalWorkerVllm, +from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor +from vllm.executor.multiproc_worker_utils import ( ResultHandler, WorkerMonitor, ) -from vllm.executor.multiproc_gpu_executor import ( - MultiProcGPUExecutor, - _create_worker, -) from vllm.utils import ( Counter, get_distributed_init_method, - set_cuda_visible_devices, + get_ip, + get_open_port, + get_vllm_instance_id, ) -from vllm.worker.worker import init_worker_distributed_environment +from vllm.worker.worker import Worker from .abstract import AbstractSamplingEngine +from .vllm_worker_utils import ManagedProcessWorkerWrapper if TYPE_CHECKING: from threading import Barrier from vllm import LLMEngine, SamplingParams - from vllm.engine.arg_utils import EngineConfig - from vllm.worker.worker_base import WorkerBase class SampleOutput(NamedTuple): @@ -69,34 +66,16 @@ class SynchronizationBarriers(NamedTuple): after_generation: Barrier -def _ensure_torch_dist_is_initialized() -> None: - import torch.distributed - - assert torch.distributed.is_initialized() - - -def _get_rdvz_url() -> str: - """Obtain rendezvous url for Torch dist.""" - return get_distributed_init_method( - os.environ.get("MASTER_ADDR", "127.0.0.1"), - int(os.environ["MASTER_PORT"]), - ) +class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): + """MultiProcGPUExecutor, but with VectorLM supplied.""" - -class ManagedMultiProcGPUExecutor(MultiProcGPUExecutor): - """MultiProcGPUExecutor, but with worker processes instantiated outside.""" - - workers: tuple[LocalWorkerVllm, ...] | None = None - vectorlm_main_fn: Callable[[], None] | None = None - result_handler: ResultHandler | None = None + vectorlm_fn: Callable[[int], None] def _init_executor(self) -> None: - """Initialize executor without initializing workers. - - Same as MultiProcGPUExecutor but assumes self.workers is already set. + """Launch vectorlm logic in workers. - Mostly reproduced from - vllm/vllm-ray-optional/vllm/executor/multiproc_gpu_executor.py + Supply barriers and pickle-compatible vectorlm main fn to + workers via vLLM multiprocessing messaging mechanisms. """ assert ( not self.speculative_config @@ -104,14 +83,15 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. world_size = self.parallel_config.tensor_parallel_size - assert self.workers is not None - assert len(self.workers) == world_size - 1, ( - f"non-driver workers len(self.workers): {len(self.workers)} " - f"should be (world_size - 1) {world_size - 1}" - ) + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers if "CUDA_VISIBLE_DEVICES" not in os.environ: - set_cuda_visible_devices(range(world_size)) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + map(str, range(world_size)) + ) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() from torch.cuda import device_count @@ -119,36 +99,80 @@ def _init_executor(self) -> None: world_size <= device_count() ), "please set tensor_parallel_size to less than max local gpu count" - assert self.result_handler is not None - self.worker_monitor = WorkerMonitor( - list(self.workers), - self.result_handler, + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port() + ) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + ManagedProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + ), + partial(self.vectorlm_fn, rank), + ) + for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method, ) - self.result_handler.start() - self.worker_monitor.start() - - distributed_init_method = _get_rdvz_url() - - # driver worker is of rank 0 - print("driver worker: init_worker_dist started") - init_worker_distributed_environment( - self.parallel_config, - 0, - distributed_init_method, - 0, + self.rank_0_vectorlm_thread = threading.Thread( + target=partial(self.vectorlm_fn, 0), ) - print("driver worker: init_worker_dist completed") - _ensure_torch_dist_is_initialized() - - # start vectorlm logic in the same Python process - # (albeit in a separate thread) - self.vectorlm_thread = threading.Thread( - target=self.vectorlm_main_fn, - name="driver/vectorlm", + self.rank_0_vectorlm_thread.start() + + self._run_workers("init_device") + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, ) - self.vectorlm_thread.start() - self._init_driver_worker_and_model(0, 0, distributed_init_method) + +class VectorLMWorker(Worker): + """Worker for running VectorLM logic alongside vLLM worker. + + Use this instance for the rank 0 (root) process. + + Note that nccl requires that only one process may have access + to each GPU. Each LocalWorkerVllm is a multiprocessing.Process. + Vectorlm logic would be launched as a thread within each of these + proceses. + + Spawn no more than one such instance for each GPU. + + Attributes + ---------- + vectorlm_thread: threading.Thread. + + """ + + barriers: SynchronizationBarriers + vectorlm_fn: Callable[[SynchronizationBarriers, int], None] + + def launch_vectorlm(self) -> None: + """Launch vectorlm logic in a separate thread. + + Params: + ------ + vectorlm_fn: VectorLM logic. Requires no argument. Be sure to + populate all arguments via functools.partial. + barriers: SynchronizationBarriers for synchronizing VectorLM + and vLLM access to NCCL. + """ + assert hasattr(self, "barriers") + assert hasattr(self, "vectorlm_fn") class ManagedLLM(LLM): @@ -213,28 +237,6 @@ def handle_sample( return sample_outputs -def get_vllm_worker_factory( - engine_config: EngineConfig, - distributed_init_method: str, - rank: int, -) -> Callable[[], WorkerBase]: - """Initialize vLLM worker.""" - return partial( - _create_worker, - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - local_rank=rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=engine_config.lora_config, - vision_language_config=engine_config.vision_language_config, - tensorizer_config=engine_config.tensorizer_config, - ) - - Fn = TypeVar("Fn", bound=Callable[..., Any]) diff --git a/vectorlm/sampling/vllm_worker_utils.py b/vectorlm/sampling/vllm_worker_utils.py new file mode 100644 index 0000000..f71ba48 --- /dev/null +++ b/vectorlm/sampling/vllm_worker_utils.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Any, Callable + +from vllm.executor.multiproc_worker_utils import ( + ProcessWorkerWrapper, + ResultHandler, + _run_worker_process, + mp, +) +from vllm.logger import init_logger + +if TYPE_CHECKING: + from multiprocessing import Queue + from multiprocessing.process import BaseProcess + +logger = init_logger(__name__) +JOIN_TIMEOUT_S = 2 + + +class ManagedProcessWorkerWrapper(ProcessWorkerWrapper): + """Wrap ProcessWorkerWrapper to add vectorlm thread to vllm process.""" + + def __init__( + self, + result_handler: ResultHandler, + worker_factory: Callable[[], Any], + vectorlm_fn: Callable[[], None], + ) -> None: + """Initialize multiprocessing queues and launch worker process.""" + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + + self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] + target=_run_worker_process_and_vectorlm_thread, + name="VllmWorkerProcess", + kwargs={ + "worker_factory": worker_factory, + "task_queue": self._task_queue, + "result_queue": self.result_queue, + "vectorlm_fn": vectorlm_fn, + }, + daemon=True, + ) + + self.process.start() + + +def _run_worker_process_and_vectorlm_thread( + worker_factory: Callable[[], Any], + task_queue: Queue, + result_queue: Queue, + vectorlm_fn: Callable[[], None], +) -> None: + """Invoke _run_worker_process and vectorlm logic in separate thread.""" + # Add process-specific prefix to stdout and stderr + + vectorlm_thread = threading.Thread(target=vectorlm_fn) + vectorlm_thread.start() + + _run_worker_process(worker_factory, task_queue, result_queue) From 9585c0197981fbaa3859e5e2fce0a6a834b108af Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 23 May 2024 14:14:11 -0400 Subject: [PATCH 80/89] [WIP] vllm hotswapping: Reduced area of vLLM integration interface. Cleanup is required. --- examples/llama_example_mp.py | 13 +-- vectorlm/sampling/utils.py | 141 +++++-------------------- vectorlm/sampling/vllm_worker_utils.py | 63 ----------- 3 files changed, 35 insertions(+), 182 deletions(-) delete mode 100644 vectorlm/sampling/vllm_worker_utils.py diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index c39a5fd..cb4fef0 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -101,7 +101,8 @@ def join_vectorlm_thread(self) -> None: assert self.llm_engine is not None model_executor = self.llm_engine.model_executor assert isinstance(model_executor, ManagedMultiProcGPUExecutor) - model_executor.rank_0_vectorlm_thread.join() + assert model_executor.driver_worker is not None + model_executor.driver_worker.vectorlm_thread.join() if __name__ == "__main__": @@ -112,9 +113,9 @@ def join_vectorlm_thread(self) -> None: world_size: int = args.world_size vectorlm_config = Config(yaml_path=args.yaml_path) - sampler_config = vectorlm_config.train_parameters.sampler # type: ignore[] + sampler_config = vectorlm_config.train_parameters.sampler # type: ignore[reportAttributeAccessIssue] vllm_engine_config = EngineArgs( - model=vectorlm_config.model, # type: ignore[] + model=vectorlm_config.model, # type: ignore[reportAttributeAccessIssue] gpu_memory_utilization=sampler_config.get( "gpu_memory_utilization", 0.35, @@ -129,10 +130,10 @@ def join_vectorlm_thread(self) -> None: # threads as long as vLLM tasks are running. barriers = SynchronizationBarriers( # (n+1) threads: __main__, and n vectorlm threads (including main). - mp.Barrier(world_size + 1), + vllm_init=mp.Barrier(world_size + 1), # n vectorlm threads. - mp.Barrier(world_size), - mp.Barrier(world_size), + before_generation=mp.Barrier(world_size), + after_generation=mp.Barrier(world_size), ) vllm_result_handler = ResultHandler() diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 7306979..b4ed496 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -3,34 +3,21 @@ from __future__ import annotations import json -import os import threading import time -from functools import partial from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar from vllm import LLM from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor -from vllm.executor.multiproc_worker_utils import ( - ResultHandler, - WorkerMonitor, -) -from vllm.utils import ( - Counter, - get_distributed_init_method, - get_ip, - get_open_port, - get_vllm_instance_id, -) -from vllm.worker.worker import Worker +from vllm.utils import Counter from .abstract import AbstractSamplingEngine -from .vllm_worker_utils import ManagedProcessWorkerWrapper if TYPE_CHECKING: from threading import Barrier from vllm import LLMEngine, SamplingParams + from vllm.worker.worker_base import WorkerBase class SampleOutput(NamedTuple): @@ -69,110 +56,38 @@ class SynchronizationBarriers(NamedTuple): class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): """MultiProcGPUExecutor, but with VectorLM supplied.""" + # only missing parameter in vectorlm_fn is local_rank. vectorlm_fn: Callable[[int], None] - def _init_executor(self) -> None: - """Launch vectorlm logic in workers. + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002,ANN003 + """Copy vectorlm_fn into this instance.""" + self.vectorlm_fn = ManagedMultiProcGPUExecutor.vectorlm_fn + super().__init__(*args, **kwargs) - Supply barriers and pickle-compatible vectorlm main fn to - workers via vLLM multiprocessing messaging mechanisms. - """ - assert ( - not self.speculative_config - ), "Speculative decoding not yet supported for MultiProcGPU backend." - - # Create the parallel GPU workers. - world_size = self.parallel_config.tensor_parallel_size - - # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers - if "CUDA_VISIBLE_DEVICES" not in os.environ: - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - map(str, range(world_size)) - ) + def _create_worker( + self, + local_rank: int = 0, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> WorkerBase: + """Instantiate worker and launch vectorlm thread. - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + For rank 0, this method is invoked "blocking" inside the rank-0 process. - from torch.cuda import device_count - - assert ( - world_size <= device_count() - ), "please set tensor_parallel_size to less than max local gpu count" - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port() - ) - - if world_size == 1: - self.workers = [] - else: - result_handler = ResultHandler() - self.workers = [ - ManagedProcessWorkerWrapper( - result_handler, - partial( - self._create_worker, - rank=rank, - local_rank=rank, - distributed_init_method=distributed_init_method, - ), - partial(self.vectorlm_fn, rank), - ) - for rank in range(1, world_size) - ] - - self.worker_monitor = WorkerMonitor(self.workers, result_handler) - result_handler.start() - self.worker_monitor.start() - - self.driver_worker = self._create_worker( - distributed_init_method=distributed_init_method, - ) - self.rank_0_vectorlm_thread = threading.Thread( - target=partial(self.vectorlm_fn, 0), - ) - self.rank_0_vectorlm_thread.start() - - self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, + For rank != 0, this method is supposed to be invoked in a child process + spawned from the main rank-0 process. + """ + vectorlm_thread = threading.Thread( + target=self.vectorlm_fn, + kwargs={"local_rank": local_rank}, ) + vectorlm_thread.start() + worker = super()._create_worker(*args, **kwargs, local_rank=local_rank) + assert worker is not None + worker.vectorlm_thread = vectorlm_thread -class VectorLMWorker(Worker): - """Worker for running VectorLM logic alongside vLLM worker. - - Use this instance for the rank 0 (root) process. - - Note that nccl requires that only one process may have access - to each GPU. Each LocalWorkerVllm is a multiprocessing.Process. - Vectorlm logic would be launched as a thread within each of these - proceses. - - Spawn no more than one such instance for each GPU. - - Attributes - ---------- - vectorlm_thread: threading.Thread. - - """ - - barriers: SynchronizationBarriers - vectorlm_fn: Callable[[SynchronizationBarriers, int], None] - - def launch_vectorlm(self) -> None: - """Launch vectorlm logic in a separate thread. - - Params: - ------ - vectorlm_fn: VectorLM logic. Requires no argument. Be sure to - populate all arguments via functools.partial. - barriers: SynchronizationBarriers for synchronizing VectorLM - and vLLM access to NCCL. - """ - assert hasattr(self, "barriers") - assert hasattr(self, "vectorlm_fn") + return worker class ManagedLLM(LLM): @@ -269,7 +184,7 @@ def multiprocess_wrap(fn: Fn, barriers: SynchronizationBarriers) -> Fn: def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 barriers.after_generation.wait() - import torch.distributed + import torch.distributed # type: ignore[reportMissingImports] rank = torch.distributed.get_rank() @@ -288,4 +203,4 @@ def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 torch.distributed.broadcast_object_list(output) return output[0] - return _wrapped_fn # type: ignore[] + return _wrapped_fn # type: ignore[reportReturnType] diff --git a/vectorlm/sampling/vllm_worker_utils.py b/vectorlm/sampling/vllm_worker_utils.py deleted file mode 100644 index f71ba48..0000000 --- a/vectorlm/sampling/vllm_worker_utils.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -import threading -from typing import TYPE_CHECKING, Any, Callable - -from vllm.executor.multiproc_worker_utils import ( - ProcessWorkerWrapper, - ResultHandler, - _run_worker_process, - mp, -) -from vllm.logger import init_logger - -if TYPE_CHECKING: - from multiprocessing import Queue - from multiprocessing.process import BaseProcess - -logger = init_logger(__name__) -JOIN_TIMEOUT_S = 2 - - -class ManagedProcessWorkerWrapper(ProcessWorkerWrapper): - """Wrap ProcessWorkerWrapper to add vectorlm thread to vllm process.""" - - def __init__( - self, - result_handler: ResultHandler, - worker_factory: Callable[[], Any], - vectorlm_fn: Callable[[], None], - ) -> None: - """Initialize multiprocessing queues and launch worker process.""" - self._task_queue = mp.Queue() - self.result_queue = result_handler.result_queue - self.tasks = result_handler.tasks - - self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] - target=_run_worker_process_and_vectorlm_thread, - name="VllmWorkerProcess", - kwargs={ - "worker_factory": worker_factory, - "task_queue": self._task_queue, - "result_queue": self.result_queue, - "vectorlm_fn": vectorlm_fn, - }, - daemon=True, - ) - - self.process.start() - - -def _run_worker_process_and_vectorlm_thread( - worker_factory: Callable[[], Any], - task_queue: Queue, - result_queue: Queue, - vectorlm_fn: Callable[[], None], -) -> None: - """Invoke _run_worker_process and vectorlm logic in separate thread.""" - # Add process-specific prefix to stdout and stderr - - vectorlm_thread = threading.Thread(target=vectorlm_fn) - vectorlm_thread.start() - - _run_worker_process(worker_factory, task_queue, result_queue) From 31464aafe1b7a13d63061537acb833d7960266cf Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 23 May 2024 14:20:23 -0400 Subject: [PATCH 81/89] vllm hotswapping [WIP]: Reduced area of vLLM integration interface. Cleanup is required. --- vectorlm/sampling/utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index b4ed496..668375e 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -54,13 +54,18 @@ class SynchronizationBarriers(NamedTuple): class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): - """MultiProcGPUExecutor, but with VectorLM supplied.""" + """MultiProcGPUExecutor, but with VectorLM launched alongside vLLM.""" # only missing parameter in vectorlm_fn is local_rank. vectorlm_fn: Callable[[int], None] def __init__(self, *args, **kwargs) -> None: # noqa: ANN002,ANN003 - """Copy vectorlm_fn into this instance.""" + """Copy class variable vectorlm_fn into this instance. + + Doing so ensures that spawned sub-processes also have access + to vectorlm_fn, which might not otherwise be accessible as a class + variable. + """ self.vectorlm_fn = ManagedMultiProcGPUExecutor.vectorlm_fn super().__init__(*args, **kwargs) @@ -196,7 +201,7 @@ def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 output = [fn(*args, **kwargs)] # fn might access torch.dist, which might conflict with - # broadcast_object_list. Hence, keep all ranks witing until fn returns. + # broadcast_object_list. Hence, keep all ranks witing until fn returns # on rank 0. barriers.before_generation.wait() From 059d57f8636edc535813eb13e0c086720faaa447 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 23 May 2024 20:57:34 -0400 Subject: [PATCH 82/89] vllm hotswapping [WIP]: Refactored vLLM integration interface to minimize changes required in llama_example.py. --- examples/llama_example.py | 70 +++------ examples/llama_example_mp.py | 93 ++--------- vectorlm/sampling/__init__.py | 9 +- vectorlm/sampling/abstract.py | 18 ++- vectorlm/sampling/sampling_lora.py | 35 ++--- vectorlm/sampling/utils.py | 244 ++++++++++++++++++++++------- vectorlm/trainer.py | 8 +- 7 files changed, 251 insertions(+), 226 deletions(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index 210e3f1..c7ce3f8 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -12,10 +12,8 @@ from torch.optim import AdamW from tqdm import tqdm from transformers import set_seed -from vllm import SamplingParams from vectorlm.dataset import Dataset -from vectorlm.sampling import LoRASamplingEngine from vectorlm.trainer import Trainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup @@ -34,9 +32,7 @@ ) if TYPE_CHECKING: - from vllm import LLM - - from vectorlm.sampling.utils import SynchronizationBarriers + from vectorlm.sampling.utils import AbstractSamplingEngine def parse_args() -> Namespace: @@ -58,50 +54,43 @@ def parse_args() -> Namespace: def main( config: Config, - world_size: int | None = None, - get_vllm_llm: Callable[[], LLM] | None = None, - barriers: SynchronizationBarriers | None = None, - local_rank: int | None = None, + get_sampling_engine: Callable[[], AbstractSamplingEngine] | None = None, ) -> None: """Define the main calling function. + WORLD_SIZE, LOCAL_RANK, and RANK are retrieved from environment vars. + Args: ---- config: vectorlm config, e.g., loaded from yaml - world_size: number of processes. - get_vllm_llm: required only for root process (rank 0). - barriers: SynchronizationBarriers, required for all processes. - local_rank: int, where 0 is root process, one process per accelerator. + get_sampling_engine: optional, blocking function that initializes the + sampling engine. Required if sampling during training is needed. + This method is provided in _VLLMCallbackWrapper. To avoid concurrent + nccl access, be sure to invoke this method before any torch method + that might also access nccl. """ - if barriers is not None: - # Wait until vllm engine is fully initialized. - print(f"rank {local_rank} vllm_init_barrier wait") - barriers.vllm_init.wait() - print(f"rank {local_rank} vllm_init_barrier cleared") + sampling_engine = ( + get_sampling_engine() if get_sampling_engine is not None else None + ) training_args = config.train_parameters - sampler_config = training_args.get("sampler") # set a seed set_seed(training_args.seed) + # set CUDA related dependencies + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + print(f"Rank: {rank}, World size: {world_size}") + if dist.is_initialized(): torch.cuda.set_device(local_rank) torch.cuda.empty_cache() - - # set CUDA related dependencies - if (local_rank is None) or (world_size is None): - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) else: - rank = local_rank # modify if going beyond one node. - os.environ["LOCAL_RANK"] = str(local_rank) - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - - print(f"Rank: {rank}, World size: {world_size}") + dist.init_process_group() # setup wandb if rank == 0 and config.enable_wandb_logging: @@ -154,6 +143,9 @@ def main( training_args.low_cpu_mem_usage, is_lora_enabled, ) + # Trigger FSDP initialization before retrieving weights. + # Otherwise FSDP is_root flag might be set incorrectly. + model(input_ids=torch.zeros((1, 1), dtype=torch.int)) # load dataset dataset = Dataset( @@ -190,24 +182,10 @@ def main( dataset, optimizer, lr_scheduler, + sampling_engine, is_peft_adapter_restored, ) - if sampler_config is not None: - # vllm_llm is required only on rank 0. - vllm_llm = ( - get_vllm_llm() - if (get_vllm_llm is not None) and (rank == 0) - else None - ) - sampling_engine = LoRASamplingEngine( - trainer, - vllm_llm, # required only for rank 0 - SamplingParams(seed=0, temperature=0), - barriers, - ) - trainer.sampling_engine = sampling_engine - # Checkpoint check. Always call before training. # If no checkpoint, it returns 0. checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index cb4fef0..39d9f3f 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -19,92 +19,20 @@ from __future__ import annotations import argparse +import os from functools import partial -from typing import Callable from llama_example import main -from vllm.engine.arg_utils import EngineArgs, EngineConfig -from vllm.engine.llm_engine import LLMEngine -from vllm.entrypoints.llm import LLM +from vllm import EngineArgs from vllm.executor.multiproc_worker_utils import ResultHandler, mp -from vectorlm.sampling.utils import ( - ManagedLLM, - ManagedMultiProcGPUExecutor, +from vectorlm.sampling import ( + LoRASamplingEngine, + SamplingEngineProvider, SynchronizationBarriers, ) from vectorlm.utils.data_utils import Config - -class _VLLMCallbackWrapper: - """Provide vLLM Engine access to multiprocess.Process workers. - - vLLM engine is initialized only after the initialize_engine call. - """ - - def __init__( - self, - engine_config: EngineConfig, - vectorlm_config: Config, - world_size: int, - barriers: SynchronizationBarriers, - ) -> None: - """Instantiate class without initializing wrapped vLLM engine.""" - self.llm_engine: LLMEngine | None = None - self.llm: LLM | None = None - self.engine_config = engine_config - self.barriers = barriers - - # Only missing args is local_rank. - self.vectorlm_fn: Callable[[int], None] = partial( - main, - vectorlm_config, - world_size, - self.get_vllm_llm, - self.barriers, - ) - - def initialize_engine(self) -> None: - """Initialize vLLM engine. - - Invoke this method only after vLLM workers are all ready. - """ - ManagedMultiProcGPUExecutor.vectorlm_fn = self.vectorlm_fn - - self.llm_engine = LLMEngine( - **self.engine_config.to_dict(), - executor_class=ManagedMultiProcGPUExecutor, - log_stats=False, - ) - - self.llm = ManagedLLM(self.llm_engine) - print(f"Instantiated ManagedLLM: {self.llm}") - - def get_vllm_llm(self) -> LLM: - """Return LLM instance. - - Invoke this method only within the main (rank 0 driver) process. - """ - assert ( - self.llm is not None - ), "Must finish initialize_engine before starting vectorlm logic." - - llm = self.llm - assert llm is not None - return llm - - def join_vectorlm_thread(self) -> None: - """Join the rank 0 (main process) vectorlm thread. - - Invoke this function only after initialize_engine. - """ - assert self.llm_engine is not None - model_executor = self.llm_engine.model_executor - assert isinstance(model_executor, ManagedMultiProcGPUExecutor) - assert model_executor.driver_worker is not None - model_executor.driver_worker.vectorlm_thread.join() - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--world_size", type=int, default=1) @@ -124,6 +52,7 @@ def join_vectorlm_thread(self) -> None: dtype=sampler_config.get("vllm_dtype", "auto"), enable_lora=True, ).create_engine_config() + os.environ["WORLD_SIZE"] = str(world_size) # Block all N vectorlm threads until main process finished # initializing vLLM Engine. Additionally, block vectorlm @@ -140,11 +69,11 @@ def join_vectorlm_thread(self) -> None: # rank 0 worker runs in the __main__ process. # all other ranks use one process each. # vectorlm logic in each ranks (including rank 0) is in a separate thread. - vllm_callback_wrapper = _VLLMCallbackWrapper( + vllm_callback_wrapper = SamplingEngineProvider( vllm_engine_config, - vectorlm_config, - world_size, barriers, + LoRASamplingEngine, + partial(main, vectorlm_config), ) vllm_callback_wrapper.initialize_engine() @@ -152,8 +81,4 @@ def join_vectorlm_thread(self) -> None: output = vllm_callback_wrapper.llm.generate("Vector Institute is") print(output[0].prompt + output[0].outputs[0].text) - print("main: vllm_init_barrier waiting") - barriers.vllm_init.wait() - print("main: vllm_init_barrier cleared") - vllm_callback_wrapper.join_vectorlm_thread() diff --git a/vectorlm/sampling/__init__.py b/vectorlm/sampling/__init__.py index 7810e25..ebf9f1c 100644 --- a/vectorlm/sampling/__init__.py +++ b/vectorlm/sampling/__init__.py @@ -1,3 +1,10 @@ from .abstract import AbstractSamplingEngine from .sampling_lora import LoRASamplingEngine -from .utils import handle_sample, multiprocess_wrap +from .utils import ( + ManagedLLM, + ManagedMultiProcGPUExecutor, + SamplingEngineProvider, + SynchronizationBarriers, + handle_sample, + multiprocess_wrap, +) diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index c9f105a..0fd8b56 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -8,6 +8,7 @@ import vllm if TYPE_CHECKING: + import torch from vectorlm.trainer import Trainer from .utils import SynchronizationBarriers @@ -18,7 +19,6 @@ class AbstractSamplingEngine(ABC): def __init__( self, - trainer: Trainer, vllm_llm: vllm.LLM | None = None, sampling_params: vllm.SamplingParams | None = None, synchronization_barriers: SynchronizationBarriers | None = None, @@ -26,26 +26,28 @@ def __init__( """Initialize sampling engine. Params: - trainer: Trainer instance. vllm_llm: Instance of vllm.LLM, required only for rank 0. sampling_params: Optionally, specify default sampling params. synchronization_barriers: Optionally, supply barriers to prevent workers from accessing GPU while vLLM is running. """ - self.trainer = trainer self.vllm_llm = vllm_llm self.sampling_params = sampling_params self.synchronization_barriers = synchronization_barriers + self.vllm_train_step = -1 - def update(self, trainer: Trainer | None = None) -> None: - """Inform the sampling engine that the model in trainer is updated. + @abstractmethod + def update(self, model: torch.nn.Module, train_step: int) -> None: + """Update model in sampling engine if the current copy is stale. Params: - trainer: Optionally, replace self.trainer with the provided value. + model: PeftModel, up-to-date model + train_step: int, train step of the given model. """ - if trainer is not None: - self.trainer = trainer + if self.vllm_train_step != train_step: + # Update parameters of self.vllm_llm using the given `model``. + return @abstractmethod def generate( diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py index 407e06b..d8716a9 100644 --- a/vectorlm/sampling/sampling_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -14,7 +14,7 @@ from .utils import SynchronizationBarriers, multiprocess_wrap if TYPE_CHECKING: - from vectorlm.trainer import Trainer + from peft.peft_model import PeftModel class LoRASamplingEngine(AbstractSamplingEngine): @@ -22,7 +22,6 @@ class LoRASamplingEngine(AbstractSamplingEngine): def __init__( self, - trainer: Trainer, vllm_llm: vllm.LLM | None = None, sampling_params: vllm.SamplingParams | None = None, synchronization_barriers: SynchronizationBarriers | None = None, @@ -31,7 +30,6 @@ def __init__( """Initialize sampling engine. Params: - trainer: Trainer instance. vllm_llm: Instance of vllm.LLM, required only for rank 0. sampling_params: Optionally, specify default sampling params. adapter_temp_folder: Temporary path where temporary adapter weights @@ -62,37 +60,22 @@ def __init__( # placeholder, as the wrapped_fn won't be invoked outside rank-0. generate_fn_raw: Callable[..., list[vllm.RequestOutput]] = ( lambda: None - ) # type: ignore [] + ) # type: ignore[reportAssignmentType] self.generate_fn = multiprocess_wrap(generate_fn_raw, self.barriers) - - # Trigger FSDP initialization before retrieving weights. - # Otherwise FSDP is_root flag might be set incorrectly. - _wrapped_model = trainer.model - assert _wrapped_model is not None - _wrapped_model(input_ids=torch.zeros((1, 1), dtype=torch.int)) self.vllm_train_step = -1 - self.update(trainer) - - def update(self, trainer: Trainer | None = None) -> None: - """Inform the sampling engine that the model in trainer is updated. + def update(self, model: PeftModel, train_step: int) -> None: + """Update model in sampling engine if the current copy is stale. Params: - trainer: Optionally, replace self.trainer with the provided value. + model: PeftModel, up-to-date model + train_step: int, train step of the given model. """ - if trainer is not None: - self.trainer = trainer - - wrapped_model = self.trainer.model - assert wrapped_model is not None - self.barriers.before_generation.wait() - if self.vllm_train_step != self.trainer.tr_step: - save_peft_adapter(wrapped_model, self.adapter_temp_folder) - assert self.trainer.tr_step is not None - assert self.trainer.tr_step >= 0 - self.vllm_train_step = self.trainer.tr_step + if self.vllm_train_step != train_step: + save_peft_adapter(model, self.adapter_temp_folder) + self.vllm_train_step = train_step self.lora_request = LoRARequest( "_vectorlm", self.vllm_train_step + 1, diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 668375e..a7fed68 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -3,11 +3,14 @@ from __future__ import annotations import json +import os import threading import time +from functools import partial from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar -from vllm import LLM +from vllm import LLM, LLMEngine, SamplingParams +from vllm.engine.arg_utils import EngineConfig from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor from vllm.utils import Counter @@ -16,10 +19,12 @@ if TYPE_CHECKING: from threading import Barrier - from vllm import LLMEngine, SamplingParams from vllm.worker.worker_base import WorkerBase +VECTORLM_WORKER_INIT_RDZV_TIMEOUT = 7 + + class SampleOutput(NamedTuple): """Represents possible responses to a prompt. @@ -53,11 +58,74 @@ class SynchronizationBarriers(NamedTuple): after_generation: Barrier +Fn = TypeVar("Fn", bound=Callable[..., Any]) + + +def multiprocess_wrap(fn: Fn, barriers: SynchronizationBarriers) -> Fn: + """Apply barrier to function and broadcast output. + + This wrapper function tries to preserve the type signature + of the wrapped function for the IDE. Tested for Pylance. + + While fn would be invoked only on rank 0, the wrapped function + should be invoked in the vectorlm thread at all ranks, so that + the barriers would block these threads from accessing GPU while + the fn is running. + + Each rank would receive the same value as output. + + Params: + ------- + fn: Function to wrap. Output needs to be compatible with pickle. + barriers: SynchronizationBarriers, only the before_generation and + after_generation barriers are required.. + + Returns + ------- + same output as Fn, but broadcasted to all ranks + (i.e., same value at all ranks) + + """ + + def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 + barriers.after_generation.wait() + + import torch.distributed # type: ignore[reportMissingImports] + + rank = torch.distributed.get_rank() + + # placeholder for output value, + # populate on rank 0 and then broadcast. + # torch requires placeholder element in object list. + output = [None] + if rank == 0: + output = [fn(*args, **kwargs)] + + # fn might access torch.dist, which might conflict with + # broadcast_object_list. Hence, keep all ranks witing until fn returns + # on rank 0. + barriers.before_generation.wait() + + torch.distributed.broadcast_object_list(output) + return output[0] + + return _wrapped_fn # type: ignore[reportReturnType] + + class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): - """MultiProcGPUExecutor, but with VectorLM launched alongside vLLM.""" + """MultiProcGPUExecutor, but with VectorLM launched alongside vLLM. + + This class is compatible as an "executor_class" for the vLLM Engine. + + NCCL requires exactly one process for each GPU, so the vLLM and VectorLM + on each GPU logic need to fit into the same process. + + This class ensures that in each of these one-per-GPU processes, + VectorLM logic would run in a separate thread alongside the vLLM Worker. + """ # only missing parameter in vectorlm_fn is local_rank. - vectorlm_fn: Callable[[int], None] + vectorlm_fn: Callable[[], None] def __init__(self, *args, **kwargs) -> None: # noqa: ANN002,ANN003 """Copy class variable vectorlm_fn into this instance. @@ -75,16 +143,18 @@ def _create_worker( *args, # noqa: ANN002 **kwargs, # noqa: ANN003 ) -> WorkerBase: - """Instantiate worker and launch vectorlm thread. + """Launch vectorlm thread and vLLM worker in the same process. For rank 0, this method is invoked "blocking" inside the rank-0 process. For rank != 0, this method is supposed to be invoked in a child process spawned from the main rank-0 process. """ + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["RANK"] = str(local_rank) vectorlm_thread = threading.Thread( target=self.vectorlm_fn, - kwargs={"local_rank": local_rank}, + name=f"Rank{local_rank}/vectorlm", ) vectorlm_thread.start() @@ -104,6 +174,114 @@ def __init__(self, llm_engine: LLMEngine) -> None: self.request_counter = Counter() +class SamplingEngineProvider: + """Provide VectorLM workers access to the SamplingEngine via a callback. + + The vLLM VectorLM logic needs to be launched alongside the vLLM worker + This class provides the VectorLM logic + + vLLM engine is initialized only after the initialize_engine call. + """ + + def __init__( + self, + engine_config: EngineConfig, + barriers: SynchronizationBarriers, + sampling_engine_class: AbstractSamplingEngine.__class__, + vectorlm_main_fn: Callable[ + [Callable[[], AbstractSamplingEngine]], + None, + ], + ) -> None: + """Instantiate class without initializing wrapped vLLM engine.""" + self.llm_engine: LLMEngine | None = None + self.llm: LLM | None = None + self.engine_config = engine_config + self.barriers = barriers + self.sampling_engine_class = sampling_engine_class + + # Only missing args is local_rank. + self.vectorlm_fn: Callable[[], None] = partial( + vectorlm_main_fn, + self.get_sampling_engine, + ) + + def initialize_engine(self) -> None: + """Initialize vLLM engine. + + Invoke this method only from the rank 0 __main__. + + This method blocks until all vectorlm threads (including rank 0) + have also reached the vllm_init barrier. + """ + ManagedMultiProcGPUExecutor.vectorlm_fn = self.vectorlm_fn + + self.llm_engine = LLMEngine( + **self.engine_config.to_dict(), + executor_class=ManagedMultiProcGPUExecutor, + log_stats=False, + ) + + self.llm = ManagedLLM(self.llm_engine) + print(f"Instantiated ManagedLLM: {self.llm}") + + thread_name = threading.current_thread().getName() + print(f"{thread_name}: vllm_init waiting") + + try: + self.barriers.vllm_init.wait(VECTORLM_WORKER_INIT_RDZV_TIMEOUT) + except threading.BrokenBarrierError as e: + msg = ( + "SamplingEngineProvider requires get_sampling_engine() to be " + "invoked across all VectorLM ranks (including rank 0) prior " + "to any Torch NCCL logic. \n" + "If sampling engine is not required, " + "please avoid using SamplingEngineProvider, as this provider " + "would launch vLLM and might hang because of concurrent NCCL " + "access. Launch the training script via torchrun instead." + ) + raise RuntimeError(msg) from e + + print(f"{thread_name}: vllm_init cleared") + + def get_sampling_engine(self) -> AbstractSamplingEngine: + """Instantiate sampling engine. + + Invoke this callback method from the VectorLM thread of each process, + including rank 0. + + SamplingEngine handles synchronization and prevents concurrent + NCCL access. Hence, a SamplingEngine instance shall be instantiated + regardless of the rank of the process. + + This method blocks until the vLLM Engine is fully initialized. + """ + thread_name = threading.current_thread().getName() + print(f"{thread_name}: vllm_init wait") + self.barriers.vllm_init.wait() + print(f"{thread_name}: vllm_init cleared") + + # vLLM is instantiated and required only for the rank 0 SamplingEngine. + assert (self.llm is not None) or (int(os.environ["LOCAL_RANK"]) != 0) + return self.sampling_engine_class( + self.llm, + SamplingParams(seed=0, temperature=0), + self.barriers, + ) + + def join_vectorlm_thread(self) -> None: + """Join the rank 0 (main process) vectorlm thread. + + Invoke this function only from __main__ (of rank 0) after + initialize_engine. + """ + assert self.llm_engine is not None + model_executor = self.llm_engine.model_executor + assert isinstance(model_executor, ManagedMultiProcGPUExecutor) + assert model_executor.driver_worker is not None + model_executor.driver_worker.vectorlm_thread.join() + + def handle_sample( sampling_engine: AbstractSamplingEngine, prompts: Iterable[str], @@ -155,57 +333,3 @@ def handle_sample( output_jsonl_file.write("\n".join(jsonl_output_lines) + "\n\n") return sample_outputs - - -Fn = TypeVar("Fn", bound=Callable[..., Any]) - - -def multiprocess_wrap(fn: Fn, barriers: SynchronizationBarriers) -> Fn: - """Apply barrier to function and broadcast output. - - This wrapper function tries to preserve the type signature - of the wrapped function for the IDE. Tested for Pylance. - - While fn would be invoked only on rank 0, the wrapped function - should be invoked in the vectorlm thread at all ranks, so that - the barriers would block these threads from accessing GPU while - the fn is running. - - Each rank would receive the same value as output. - - Params: - ------- - fn: Function to wrap. Output needs to be compatible with pickle. - barriers: SynchronizationBarriers, only the before_generation and - after_generation barriers are required.. - - Returns - ------- - same output as Fn, but broadcasted to all ranks - (i.e., same value at all ranks) - - """ - - def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 - barriers.after_generation.wait() - - import torch.distributed # type: ignore[reportMissingImports] - - rank = torch.distributed.get_rank() - - # placeholder for output value, - # populate on rank 0 and then broadcast. - # torch requires placeholder element in object list. - output = [None] - if rank == 0: - output = [fn(*args, **kwargs)] - - # fn might access torch.dist, which might conflict with - # broadcast_object_list. Hence, keep all ranks witing until fn returns - # on rank 0. - barriers.before_generation.wait() - - torch.distributed.broadcast_object_list(output) - return output[0] - - return _wrapped_fn # type: ignore[reportReturnType] diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index a03a83c..8cd9661 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -123,6 +123,7 @@ def prepare_trainer( dataset: Dataset, optimizer: Optimizer, lr_scheduler: LRScheduler | ReduceLROnPlateau, + sampling_engine: AbstractSamplingEngine | None = None, is_peft_adapter_restored: bool = False, ) -> None: """Set all essential training requirements. @@ -135,6 +136,9 @@ def prepare_trainer( optimizer: The training optimizer. lr_scheduler: The LR scheduler. + sampling_engine: Optionally, provide a sampling engine to enable + sampling during training. + is_peft_adapter_restored: whether peft is enabled and adapters were restored from filesystem. @@ -145,6 +149,8 @@ def prepare_trainer( self.optimizer = optimizer self.lr_scheduler = lr_scheduler + self.sampling_engine = sampling_engine + self.is_peft_adapter_restored = is_peft_adapter_restored def save_checkpoint(self, epoch: int) -> None: @@ -282,7 +288,7 @@ def step( if (self.sampling_engine is not None) and ( self.tr_step % self.config.sampler.sample_frequency == 0 ): - self.sampling_engine.update(self) + self.sampling_engine.update(self.model, self.tr_step) handle_sample( self.sampling_engine, self.config.sampler.prompts, From b5c6389403f628c7030f67053f99915ae4bb5a15 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 23 May 2024 21:00:21 -0400 Subject: [PATCH 83/89] vllm hotswapping [WIP]: deleted unneded torch dist.barrier from llama_example.py. --- examples/llama_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index c7ce3f8..eeb1b24 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -199,7 +199,6 @@ def main( ): batch = next(train_dl_iterator) trainer.step(batch, epoch) - dist.barrier() if epoch == training_args.epochs - 1: hf_save_dir = os.path.join(training_args.output_dir, "final-model") From f5068128a3f2454f4a6012436bb16ac6a81632a3 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Fri, 24 May 2024 09:06:51 -0400 Subject: [PATCH 84/89] vllm hotswapping [WIP]: documentation fixes and cleanup. --- examples/llama_example.py | 3 +- examples/llama_example_mp.py | 6 +- examples/train_and_inference.py | 167 -------------------------------- vectorlm/sampling/abstract.py | 2 +- vectorlm/sampling/utils.py | 2 +- 5 files changed, 6 insertions(+), 174 deletions(-) delete mode 100644 examples/train_and_inference.py diff --git a/examples/llama_example.py b/examples/llama_example.py index eeb1b24..ee1dde5 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -89,12 +89,11 @@ def main( if dist.is_initialized(): torch.cuda.set_device(local_rank) torch.cuda.empty_cache() - else: - dist.init_process_group() # setup wandb if rank == 0 and config.enable_wandb_logging: wandb_setup(config, **config.wandb_config) + dist.barrier() # load model and tokenizer model, tokenizer = load_model_and_tokenizer( diff --git a/examples/llama_example_mp.py b/examples/llama_example_mp.py index 39d9f3f..f73083d 100644 --- a/examples/llama_example_mp.py +++ b/examples/llama_example_mp.py @@ -12,8 +12,7 @@ feature (e.g., a Barrier shared across all processes) that the rank 0 process can remotely unblock. -Edit: It seems that vllm.entrypoint.llm.LLM generate calls aren't -entirely blocking. +See https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30 """ from __future__ import annotations @@ -68,7 +67,8 @@ # rank 0 worker runs in the __main__ process. # all other ranks use one process each. - # vectorlm logic in each ranks (including rank 0) is in a separate thread. + # vectorlm logic in each ranks (including rank 0) is in a separate thread + # from the vLLM worker logic. vllm_callback_wrapper = SamplingEngineProvider( vllm_engine_config, barriers, diff --git a/examples/train_and_inference.py b/examples/train_and_inference.py deleted file mode 100644 index 305cd5a..0000000 --- a/examples/train_and_inference.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -import argparse -import math -import os -import sys -from argparse import Namespace - -import torch -import torch.distributed as dist -from torch.optim import AdamW -from tqdm import tqdm -from transformers import set_seed - -from vectorlm.dataset import Dataset -from vectorlm.trainer import Trainer -from vectorlm.utils.data_utils import Config -from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup -from vectorlm.utils.model_utils import ( - get_lora_model_from_base_model, - get_submodule_by_pattern, - load_model_and_tokenizer, - shard_model, -) -from vectorlm.utils.optimizer_utils import get_custom_scheduler -from vectorlm.utils.save_utils import save_consolidated_model - - -def parse_args() -> Namespace: - """Parse command-line arguments. - - Returns - ------- - The parsed arguments. - - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "--yaml_path", - default="configs/config.yaml", - required=False, - ) - return parser.parse_args() - - -def main(config: Config) -> None: - """Define the main calling function.""" - training_args = config.train_parameters - - # set a seed - set_seed(training_args.seed) - - # set CUDA related dependencies - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - print(f"Rank: {rank}, World size: {world_size}") - if dist.is_initialized(): - torch.cuda.set_device(local_rank) - torch.cuda.empty_cache() - - # setup wandb - if rank == 0: - wandb_setup(config, **config.wandb_config) - dist.barrier() - - # load model and tokenizer - model, tokenizer = load_model_and_tokenizer( - config.model, - training_args.use_mp, - training_args.use_flash_attention, - training_args.max_seq_len, - local_rank, - training_args.low_cpu_mem_usage, - ) - - lora_peft_config = getattr( - config.train_parameters, - "lora_peft_config", - None, - ) - if lora_peft_config is not None: - model = get_lora_model_from_base_model(model, lora_peft_config) - - decoder_layer_module = get_submodule_by_pattern(model, r"DecoderLayer$") - model = shard_model( - model.bfloat16(), - decoder_layer_module, - training_args.use_mp, - training_args.use_activation_checkpointing, - training_args.sharding_strategy, - local_rank, - training_args.low_cpu_mem_usage, - ) - - # load dataset - dataset = Dataset( - config=config.dataset, - tokenizer=tokenizer, - ) - - # instantiate trainer - trainer = Trainer( - config=training_args, - enable_wandb_logging=config.enable_wandb_logging, - original_dataset_length=dataset.original_length, - ) - - # load optimizer - optimizer = AdamW( - model.parameters(), - **training_args.optimizer, - ) - - # load lr scheduler - lr_scheduler = get_custom_scheduler( - training_args.lr_scheduler_type, - optimizer, - math.ceil( - trainer.num_update_steps_per_epoch * training_args.warmup_ratio, - ), - trainer.max_steps, - ) - - trainer.prepare_trainer( - model, - tokenizer, - dataset, - optimizer, - lr_scheduler, - ) - - # Checkpoint check. Always call before training. - # If no checkpoint, it returns 0. - checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - - for epoch in range(checkpointed_epoch, training_args.epochs): - trainer.model.train() - train_dl_iterator = iter(dataset.train_dataloader) - for _ in tqdm( - range(len(dataset.train_dataloader)), - disable=rank != 0, - file=sys.__stdout__, - ): - batch = next(train_dl_iterator) - trainer.step(batch, epoch) - - if epoch == training_args.epochs - 1: - hf_save_dir = os.path.join(training_args.output_dir, "final-model") - else: - hf_save_dir = os.path.join( - training_args.output_dir, - "checkpoints", - f"epoch_{epoch}", - "end-epoch-model", - ) - save_consolidated_model(trainer.model, hf_save_dir, rank) - dataset.reset_dataloaders() - - -if __name__ == "__main__": - args = parse_args() - config = Config(yaml_path=args.yaml_path) - setup(config.train_parameters.output_dir) - main(config) - cleanup() diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index 0fd8b56..c2f3cdd 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -1,4 +1,4 @@ -"""Wrapper around sampling engine.""" +"""Wrapper around vLLM. Also handles synchronization.""" from __future__ import annotations diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index a7fed68..c0ce793 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -118,7 +118,7 @@ class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): This class is compatible as an "executor_class" for the vLLM Engine. NCCL requires exactly one process for each GPU, so the vLLM and VectorLM - on each GPU logic need to fit into the same process. + logic on each GPU need to fit into the same process. This class ensures that in each of these one-per-GPU processes, VectorLM logic would run in a separate thread alongside the vLLM Worker. From 3e27e8476274d94cb06223cdd5680eda020fe454 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Fri, 24 May 2024 09:17:14 -0400 Subject: [PATCH 85/89] vllm hotswapping [WIP]: cleaned up documentation related to multiprocess_wrap. --- vectorlm/sampling/sampling_lora.py | 4 +--- vectorlm/sampling/utils.py | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py index d8716a9..934ae9a 100644 --- a/vectorlm/sampling/sampling_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -58,9 +58,7 @@ def __init__( generate_fn_raw = vllm_llm.generate else: # placeholder, as the wrapped_fn won't be invoked outside rank-0. - generate_fn_raw: Callable[..., list[vllm.RequestOutput]] = ( - lambda: None - ) # type: ignore[reportAssignmentType] + generate_fn_raw = None self.generate_fn = multiprocess_wrap(generate_fn_raw, self.barriers) self.vllm_train_step = -1 diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index c0ce793..df233c3 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -61,7 +61,7 @@ class SynchronizationBarriers(NamedTuple): Fn = TypeVar("Fn", bound=Callable[..., Any]) -def multiprocess_wrap(fn: Fn, barriers: SynchronizationBarriers) -> Fn: +def multiprocess_wrap(fn: Fn | None, barriers: SynchronizationBarriers) -> Fn: """Apply barrier to function and broadcast output. This wrapper function tries to preserve the type signature @@ -70,13 +70,16 @@ def multiprocess_wrap(fn: Fn, barriers: SynchronizationBarriers) -> Fn: While fn would be invoked only on rank 0, the wrapped function should be invoked in the vectorlm thread at all ranks, so that the barriers would block these threads from accessing GPU while - the fn is running. + the fn is running. However, only the fn specified at rank 0 + will actually be executed. The fn parameter can be None at all + other ranks. Each rank would receive the same value as output. Params: ------- fn: Function to wrap. Output needs to be compatible with pickle. + This arg is required only on rank 0. barriers: SynchronizationBarriers, only the before_generation and after_generation barriers are required.. @@ -99,6 +102,7 @@ def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 # torch requires placeholder element in object list. output = [None] if rank == 0: + assert fn is not None output = [fn(*args, **kwargs)] # fn might access torch.dist, which might conflict with From 879399f053a04150b32b07198343d589585f9bc6 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Fri, 24 May 2024 09:22:29 -0400 Subject: [PATCH 86/89] vllm hotswapping [WIP]: cleaned up changes in llama_example.py. --- examples/llama_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llama_example.py b/examples/llama_example.py index ee1dde5..abe76b4 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -85,7 +85,6 @@ def main( world_size = int(os.environ["WORLD_SIZE"]) print(f"Rank: {rank}, World size: {world_size}") - if dist.is_initialized(): torch.cuda.set_device(local_rank) torch.cuda.empty_cache() From bc0ae52602feef4813c9cfb59cd93d678230322d Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Fri, 24 May 2024 10:10:41 -0400 Subject: [PATCH 87/89] vllm hotswapping [WIP]: added example gemma sampling config. --- configs/config_gemma.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/config_gemma.yaml b/configs/config_gemma.yaml index 2e9d16e..fb2b6fa 100644 --- a/configs/config_gemma.yaml +++ b/configs/config_gemma.yaml @@ -7,7 +7,7 @@ wandb_config: # tags: ["20240418-1a-preemption"] train_parameters: - output_dir: /network/scratch/j/jacob-junqi.tian/vectorlm/weights + output_dir: weights max_seq_len: 128 epochs: 10 seed: 11 From 5e8944d0fe27ce1ed6b2405c2202179391bd6cbe Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Mon, 17 Jun 2024 20:51:23 -0400 Subject: [PATCH 88/89] vllm hotswapping: Refactoring and cleanup. --- configs/config_gemma.yaml | 68 ------------------- docs/config.md | 2 +- examples/__init__.py | 0 examples/llama_example.py | 19 ++++++ ..._example_mp.py => lora_hotswap_example.py} | 5 +- 5 files changed, 23 insertions(+), 71 deletions(-) delete mode 100644 configs/config_gemma.yaml delete mode 100644 examples/__init__.py rename examples/{llama_example_mp.py => lora_hotswap_example.py} (94%) diff --git a/configs/config_gemma.yaml b/configs/config_gemma.yaml deleted file mode 100644 index fb2b6fa..0000000 --- a/configs/config_gemma.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: google/gemma-2b -enable_wandb_logging: False - -wandb_config: - project: vector-lm-verify - name: benchmark-lora - # tags: ["20240418-1a-preemption"] - -train_parameters: - output_dir: weights - max_seq_len: 128 - epochs: 10 - seed: 11 - - # Sharding strategy - sharding_strategy: FULL_SHARD - - # Memory - use_mp: True - use_activation_checkpointing: True - # use_flash_attention is automatically enabled - # for CUDA capability > 8.0 - use_flash_attention: False - low_cpu_mem_usage: True - - lora_peft_config: - task_type: CAUSAL_LM - inference_mode: False - r: 8 - lora_alpha: 32 - lora_dropout: 0.1 - - # Gradient norm clipping - max_grad_norm: 1 - gradient_accumulation_steps: 4 - - # Optimizer - optimizer: - lr: 1.0e-4 - weight_decay: 0.1 - betas: [0.9, 0.95] - eps: 1.0e-5 - - # Scheduler - lr_scheduler_type: cosine - warmup_ratio: 0.05 - - # Checkpointing - checkpointing_enabled: False - logging_steps: 10 - save_frequency: 0.10 - - # Sampling during training - sampler: - sample_frequency: 8 - output_jsonl_path: data/output-5e-5-2b.jsonl - vllm_dtype: half - prompts: - - "Vector Institute of the" - - "Vector Institute is located in the city of" - - "The answer to life the universe and everything is" - -dataset: - ignore_index: -100 - eval_bs: 8 - train_bs: 8 - train_ds: data/processed/vector-west/train - eval_ds: data/processed/vector-west/test diff --git a/docs/config.md b/docs/config.md index 2951d15..0890a7a 100644 --- a/docs/config.md +++ b/docs/config.md @@ -54,7 +54,7 @@ Similar to the wandb config above, these keyword parameters are fed directly int ### Sampling during Training -To disable sampling during training, delete the entire "sampling" section. +To disable sampling during training, comment out the entire "sampling" section. * `sample_frequency`: Number of train steps between two consecutive sampling steps. * `output_jsonl_path`: Optional; write sampled output to the specified jsonl file. diff --git a/examples/__init__.py b/examples/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/llama_example.py b/examples/llama_example.py index abe76b4..4bd7017 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -35,6 +35,23 @@ from vectorlm.sampling.utils import AbstractSamplingEngine +SAMPLER_NOT_PROVIDED_ERROR_MSG = """ +Hot-swap sampling is enabled but sampler engine is not provided. \ +Did you launch this script via `torchrun llama_example.py`? \ +To enable hotswap vLLM sampling during training, launch the \ +training script via `python3 lora_hotswap_example.py` directly \ +without using Torchrun, especially when running in multi-GPU environments. \ + +Custom logic in lora_hotswap_example are required to handles multi-GPU \ +synchronization and prevent NCCL conflicts with vLLM Engine when running \ +in multi-GPU setups. \ + +If you have renamed llama_example.py, be sure to adjust the import in \ +lora_hotswap_example.py to load the correct `main` function for the training \ +loop. +""" + + def parse_args() -> Namespace: """Parse command-line arguments. @@ -73,6 +90,8 @@ def main( sampling_engine = ( get_sampling_engine() if get_sampling_engine is not None else None ) + if config.train_parameters.get("sampler") is not None: + assert sampling_engine is not None, SAMPLER_NOT_PROVIDED_ERROR_MSG training_args = config.train_parameters diff --git a/examples/llama_example_mp.py b/examples/lora_hotswap_example.py similarity index 94% rename from examples/llama_example_mp.py rename to examples/lora_hotswap_example.py index f73083d..55c39c7 100644 --- a/examples/llama_example_mp.py +++ b/examples/lora_hotswap_example.py @@ -1,4 +1,4 @@ -"""llama_example, but uses multiprocessing in place of torchrun. +"""Supply LoRASamplingEngine to llama_example. Each non-rank-0 worker process should spawn vectorlm logic in a separate thread (but same process) but won't run the actual @@ -12,7 +12,8 @@ feature (e.g., a Barrier shared across all processes) that the rank 0 process can remotely unblock. -See https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30 +See docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30 +for more detail on this architecture. """ from __future__ import annotations From 2005a7dc7a11c5afa04a4b7ca2e51486cc52a4c7 Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Tue, 18 Jun 2024 10:41:53 -0400 Subject: [PATCH 89/89] vllm hotswapping: Moved Sampler import into conditional block to avoid importing vLLM when not required. Ruff formatting fixes. --- vectorlm/sampling/abstract.py | 1 - vectorlm/sampling/sampling_lora.py | 3 +-- vectorlm/trainer.py | 5 +++-- vectorlm/utils/misc_utils.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index c2f3cdd..0747722 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch - from vectorlm.trainer import Trainer from .utils import SynchronizationBarriers diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py index 934ae9a..32dfbcc 100644 --- a/vectorlm/sampling/sampling_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -1,9 +1,8 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING -import torch import torch.distributed as dist import vllm from vllm.lora.request import LoRARequest diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 8cd9661..c571bf7 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -7,13 +7,12 @@ import peft import torch import torch.distributed as dist +import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from transformers import PreTrainedTokenizer -import wandb from vectorlm.dataset import Dataset -from vectorlm.sampling import handle_sample from vectorlm.utils.data_utils import Config from vectorlm.utils.save_utils import ( checkpoint_exists, @@ -288,6 +287,8 @@ def step( if (self.sampling_engine is not None) and ( self.tr_step % self.config.sampler.sample_frequency == 0 ): + from vectorlm.sampling import handle_sample + self.sampling_engine.update(self.model, self.tr_step) handle_sample( self.sampling_engine, diff --git a/vectorlm/utils/misc_utils.py b/vectorlm/utils/misc_utils.py index 1c59b98..081deb2 100644 --- a/vectorlm/utils/misc_utils.py +++ b/vectorlm/utils/misc_utils.py @@ -4,8 +4,8 @@ from typing import Any import torch.distributed as dist - import wandb + from vectorlm.utils.data_utils import Config