diff --git a/.gitignore b/.gitignore
index 05c42cc..f4c948d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ data/
 **/*.pyc
 /.cache
 /.vscode
-/data
\ No newline at end of file
+/data
+/env
diff --git a/configs/config.yaml b/configs/config.yaml
index a2a03e7..6758a2b 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -21,7 +21,7 @@ train_parameters:
   low_cpu_mem_usage: True
 
   # LoRA config: uncomment the block below to enable LoRA
-  
+
   # lora_peft_config:
   #   task_type: CAUSAL_LM
   #   inference_mode: False
@@ -29,7 +29,6 @@ train_parameters:
   #   lora_alpha: 32
   #   lora_dropout: 0.1
-
 
   # Gradient norm clipping
   max_grad_norm: 1
   gradient_accumulation_steps: 4
@@ -50,6 +49,17 @@ train_parameters:
   logging_steps: 500
   save_frequency: 0.25
 
+  # Sampling during training
+  # Uncomment the block below to enable.
+
+  # sampler:
+  #   sample_frequency: 8
+  #   output_jsonl_path: data/output.jsonl
+  #   prompts:
+  #     - "Vector Institute of the"
+  #     - "Vector Institute is located in the city of"
+  #     - "The answer to life the universe and everything is"
+
 dataset:
   ignore_index: -100
   eval_bs: 8
diff --git a/docs/config.md b/docs/config.md
index c48592d..0890a7a 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -51,6 +51,22 @@ Similar to the wandb config above, these keyword parameters are fed directly int
 * `logging_steps`: How often evaluation is run using the evaluation dataset.
 * `save_frequency`: The frequency at which checkpointing occurs. This must be between 0 and 1.
+
+### Sampling during Training
+
+To disable sampling during training, comment out the entire `sampler` section.
+
+* `sample_frequency`: Number of training steps between two consecutive sampling steps.
+* `output_jsonl_path`: Optional; write sampled output to the specified JSONL file.
+* `prompts`: YAML list of prompt strings.
+
+Each line of the output JSONL file is a dictionary with the following keys:
+
+* `train_step`: integer, the trainer step at which this line was generated.
+* `prompt`: string.
+* `options`: list of strings, one for each option that the sampler provided.
+* `time_taken`: float, the number of seconds taken to generate **all** prompts at this step.
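+
+For illustration, assuming the default prompts above, a single line of the output file might look roughly like the following (the values shown here are made up):
+
+```json
+{"train_step": 500, "prompt": "Vector Institute is located in the city of", "options": ["Toronto."], "time_taken": 1.72}
+```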
+
 
 ## Dataset
 
 * `ignore_index`: The integer index used to ignore a given token in the loss calculation. Cross-entropy loss by default uses `-100`.
diff --git a/docs/sampling.md b/docs/sampling.md
new file mode 100644
index 0000000..90cf4c5
--- /dev/null
+++ b/docs/sampling.md
@@ -0,0 +1,26 @@
+# Efficient Sampling during training
+
+Some training objectives, notably PPO, require "sampling" from the language model many times during training. The most straightforward approach might be to invoke `model.generate` on the model from within the training loop. However, a number of alternative inference solutions, vLLM among them, promise over 10x the sampling throughput (tokens generated per second) when using a large sampling batch size. If `model.generate` is taking up too much of the training time, it might be worth looking into these third-party solutions to speed up sampling.
+
+One main challenge of running these third-party solutions is that most of them assume the weights of the language model are fixed, so there is no straightforward way to update the weights. Usually, updating the weights requires restarting the sampling engine, which can take minutes. At the same time, the performance of PPO and similar techniques relies heavily on the ability to replace the weights efficiently; otherwise the training is no longer on-policy and convergence takes substantially more training steps. To resolve this issue, we implement techniques to "hot-swap" the model parameters that are used in the sampling process.
+
+Additionally, it is not straightforward to ensure consistently high GPU utilization when combining sampling with training.
+This repository enables you to make the most of all your GPUs by fitting vLLM and your training loop onto the same set of devices. This way, none of the GPUs sit idle: if a GPU is not running training, it is busy sampling with vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach.
+
+## Example: Supervised fine-tuning
+
+We provide a basic example that samples from the language model while fine-tuning with a causal language modelling objective. To run the example, uncomment the `sampler` section in your configuration YAML, choose a port for `nccl` coordination, and run the following command (not using torchrun):
+
+```
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=19132
+python3 examples/llama_example_mp.py \
+--yaml_path configs/config.yaml \
+--world_size 2
+```
+
+## Bring your own training loop
+
+While the reference implementation covers only supervised fine-tuning, we provide abstractions that make it easier for you to implement your own training loop, be it PPO RLHF, TWIST, or something else. The goal is to abstract away all the synchronization logic, so that a training loop you've built on one GPU can scale to multiple GPUs on the same server with minimal modifications.
+
+To get started, refer to `examples/llama_example.py` and `vectorlm/trainer.py`. Usually, the vLLM Engine is accessible only from rank 0, making synchronization challenging. When invoked through llama_example_mp, the `SamplingEngine` interface in VectorLM enables your training loop to access `vllm.LLM.generate` from all ranks, returning the same result across all ranks. Note that because the synchronization barriers require all ranks to reach the synchronization point, you need to invoke `generate` from all ranks.
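+
+For illustration, here is a minimal sketch of how these calls might sit inside a custom training loop. The names `sampling_engine`, `model`, `prompts`, `rank`, `step`, `sample_frequency`, and `output_jsonl_path` are placeholders from your own loop; the `update` and `handle_sample` calls follow the pattern used in `vectorlm/trainer.py` in this repository:
+
+```python
+from vectorlm.sampling import handle_sample
+
+# Run this at every rank: the sampling engine's barriers expect all ranks
+# to reach the update/generate calls together.
+if sampling_engine is not None and step % sample_frequency == 0:
+    # Hot-swap the latest weights into vLLM so sampling stays on-policy.
+    sampling_engine.update(model, step)
+    # Sample continuations; write the JSONL output on rank 0 only.
+    handle_sample(
+        sampling_engine,
+        prompts,
+        output_path=output_jsonl_path if rank == 0 else None,
+        extra_data={"train_step": step},
+    )
+```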
diff --git a/examples/launch_lora.sh b/examples/launch_lora.sh
new file mode 100644
index 0000000..250daef
--- /dev/null
+++ b/examples/launch_lora.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+#SBATCH --job-name=llama7b-2
+#SBATCH --nodes=1
+#SBATCH --mem=0
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-gpu=6
+#SBATCH --gres=gpu:4
+#SBATCH --output=llama-2-7b.%j.out
+#SBATCH --error=llama-2-7b.%j.err
+#SBATCH --partition=a100
+#SBATCH --qos=your_assigned_qos  # CHANGE
+#SBATCH --open-mode=append
+#SBATCH --wait-all-nodes=1
+#SBATCH --time=3-00
+
+export NCCL_IB_DISABLE=1  # Our cluster does not have InfiniBand; disable it via this flag.
+export NCCL_DEBUG=WARN
+export NCCL_DEBUG_SUBSYS=WARN
+
+# export TORCH_DISTRIBUTED_DEBUG=DETAIL  # Uncomment these flags for debugging communication
+# export TORCH_CPP_LOG_LEVEL=INFO
+export LOGLEVEL=INFO
+export PYTHONFAULTHANDLER=1
+# export CUDA_LAUNCH_BLOCKING=0
+
+torchrun --nnodes=1 --nproc-per-node=${SLURM_GPUS_ON_NODE} example_lora.py --yaml_path configs/config-lora.yaml
diff --git a/examples/launch_lora_one_gpu.sh b/examples/launch_lora_one_gpu.sh
new file mode 100644
index 0000000..b030854
--- /dev/null
+++ b/examples/launch_lora_one_gpu.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+#SBATCH --job-name=llama7b-2-lora
+#SBATCH --nodes=1
+#SBATCH --mem=32GB
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-gpu=6
+#SBATCH --gres=gpu:1
+#SBATCH --output=llama-2-7b-lora.%j.out
+#SBATCH --error=llama-2-7b-lora.%j.err
+#SBATCH --partition=a100
+#SBATCH --qos=your_assigned_qos  # CHANGE
+#SBATCH --open-mode=append
+#SBATCH --wait-all-nodes=1
+#SBATCH --time=3-00
+
+export NCCL_IB_DISABLE=1  # Our cluster does not have InfiniBand; disable it via this flag.
+export NCCL_DEBUG=WARN
+export NCCL_DEBUG_SUBSYS=WARN
+
+# export TORCH_DISTRIBUTED_DEBUG=DETAIL  # Uncomment these flags for debugging communication
+# export TORCH_CPP_LOG_LEVEL=INFO
+export LOGLEVEL=INFO
+export PYTHONFAULTHANDLER=1
+# export CUDA_LAUNCH_BLOCKING=0
+
+torchrun --nnodes=1 --nproc-per-node=1 example_lora.py --yaml_path configs/config-lora.yaml
diff --git a/examples/llama_example.py b/examples/llama_example.py
index b9cdb40..4bd7017 100644
--- a/examples/llama_example.py
+++ b/examples/llama_example.py
@@ -5,6 +5,7 @@
 import os
 import sys
 from argparse import Namespace
+from typing import TYPE_CHECKING, Callable
 
 import torch
 import torch.distributed as dist
@@ -30,6 +31,26 @@
     save_peft_adapter,
 )
 
+if TYPE_CHECKING:
+    from vectorlm.sampling.utils import AbstractSamplingEngine
+
+
+SAMPLER_NOT_PROVIDED_ERROR_MSG = """
+Hot-swap sampling is enabled but no sampling engine was provided. \
+Did you launch this script via `torchrun llama_example.py`? \
+To enable hot-swap vLLM sampling during training, launch the \
+training script via `python3 lora_hotswap_example.py` directly, \
+without torchrun, especially when running in multi-GPU environments. \
+
+Custom logic in lora_hotswap_example is required to handle multi-GPU \
+synchronization and prevent NCCL conflicts with the vLLM Engine when running \
+in multi-GPU setups. \
+
+If you have renamed llama_example.py, be sure to adjust the import in \
+lora_hotswap_example.py to load the correct `main` function for the training \
+loop.
+"""
+
 
 def parse_args() -> Namespace:
     """Parse command-line arguments.
@@ -48,8 +69,30 @@ def parse_args() -> Namespace:
     return parser.parse_args()
 
 
-def main(config: Config) -> None:
-    """Define the main calling function."""
+def main(
+    config: Config,
+    get_sampling_engine: Callable[[], AbstractSamplingEngine] | None = None,
+) -> None:
+    """Define the main calling function.
+
+    WORLD_SIZE, LOCAL_RANK, and RANK are retrieved from environment vars.
+
+    Args:
+    ----
+        config: vectorlm config, e.g., loaded from yaml
+        get_sampling_engine: optional, blocking function that initializes the
+            sampling engine. Required if sampling during training is needed.
+            This method is provided by SamplingEngineProvider. To avoid concurrent
+            nccl access, be sure to invoke this method before any torch method
+            that might also access nccl.
+ + """ + sampling_engine = ( + get_sampling_engine() if get_sampling_engine is not None else None + ) + if config.train_parameters.get("sampler") is not None: + assert sampling_engine is not None, SAMPLER_NOT_PROVIDED_ERROR_MSG + training_args = config.train_parameters # set a seed @@ -66,7 +109,7 @@ def main(config: Config) -> None: torch.cuda.empty_cache() # setup wandb - if rank == 0: + if rank == 0 and config.enable_wandb_logging: wandb_setup(config, **config.wandb_config) dist.barrier() @@ -87,7 +130,9 @@ def main(config: Config) -> None: is_lora_enabled = True peft_adapter_path = None # Restore peft adapter from filesystem if available. - if checkpoint_exists(training_args.output_dir): + if (training_args.checkpointing_enabled) and checkpoint_exists( + training_args.output_dir, + ): peft_adapter_path = os.path.join( training_args.output_dir, "checkpoints", @@ -115,6 +160,9 @@ def main(config: Config) -> None: training_args.low_cpu_mem_usage, is_lora_enabled, ) + # Trigger FSDP initialization before retrieving weights. + # Otherwise FSDP is_root flag might be set incorrectly. + model(input_ids=torch.zeros((1, 1), dtype=torch.int)) # load dataset dataset = Dataset( @@ -151,6 +199,7 @@ def main(config: Config) -> None: dataset, optimizer, lr_scheduler, + sampling_engine, is_peft_adapter_restored, ) @@ -159,7 +208,6 @@ def main(config: Config) -> None: checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) for epoch in range(checkpointed_epoch, training_args.epochs): - trainer.model.train() train_dl_iterator = iter(dataset.train_dataloader) for _ in tqdm( range(len(dataset.train_dataloader)), @@ -185,6 +233,8 @@ def main(config: Config) -> None: save_consolidated_model(trainer.model, hf_save_dir, rank) dataset.reset_dataloaders() + sys.exit(0) + if __name__ == "__main__": args = parse_args() diff --git a/examples/lora_hotswap_example.py b/examples/lora_hotswap_example.py new file mode 100644 index 0000000..55c39c7 --- /dev/null +++ b/examples/lora_hotswap_example.py @@ -0,0 +1,85 @@ +"""Supply LoRASamplingEngine to llama_example. + +Each non-rank-0 worker process should spawn vectorlm logic in a +separate thread (but same process) but won't run the actual +vectorlm logic until the vLLM Engine is initialized- inference +weights loaded into each worker. + +To do so without duplicating vLLM code, observe that only the main process +(rank 0) is aware that vLLM engine was initialized properly +(when LLMEngine.__init__ returns.) Hence, one way to implement this +setup is to block the vectorlm thread with a multiprocessing synchronization +feature (e.g., a Barrier shared across all processes) that the rank 0 process +can remotely unblock. + +See docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30 +for more detail on this architecture. 
+""" + +from __future__ import annotations + +import argparse +import os +from functools import partial + +from llama_example import main +from vllm import EngineArgs +from vllm.executor.multiproc_worker_utils import ResultHandler, mp + +from vectorlm.sampling import ( + LoRASamplingEngine, + SamplingEngineProvider, + SynchronizationBarriers, +) +from vectorlm.utils.data_utils import Config + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--world_size", type=int, default=1) + parser.add_argument("--yaml_path", type=str, required=True) + args = parser.parse_args() + + world_size: int = args.world_size + vectorlm_config = Config(yaml_path=args.yaml_path) + sampler_config = vectorlm_config.train_parameters.sampler # type: ignore[reportAttributeAccessIssue] + vllm_engine_config = EngineArgs( + model=vectorlm_config.model, # type: ignore[reportAttributeAccessIssue] + gpu_memory_utilization=sampler_config.get( + "gpu_memory_utilization", + 0.35, + ), + tensor_parallel_size=world_size, + dtype=sampler_config.get("vllm_dtype", "auto"), + enable_lora=True, + ).create_engine_config() + os.environ["WORLD_SIZE"] = str(world_size) + + # Block all N vectorlm threads until main process finished + # initializing vLLM Engine. Additionally, block vectorlm + # threads as long as vLLM tasks are running. + barriers = SynchronizationBarriers( + # (n+1) threads: __main__, and n vectorlm threads (including main). + vllm_init=mp.Barrier(world_size + 1), + # n vectorlm threads. + before_generation=mp.Barrier(world_size), + after_generation=mp.Barrier(world_size), + ) + vllm_result_handler = ResultHandler() + + # rank 0 worker runs in the __main__ process. + # all other ranks use one process each. + # vectorlm logic in each ranks (including rank 0) is in a separate thread + # from the vLLM worker logic. + vllm_callback_wrapper = SamplingEngineProvider( + vllm_engine_config, + barriers, + LoRASamplingEngine, + partial(main, vectorlm_config), + ) + + vllm_callback_wrapper.initialize_engine() + assert vllm_callback_wrapper.llm is not None + output = vllm_callback_wrapper.llm.generate("Vector Institute is") + print(output[0].prompt + output[0].outputs[0].text) + + vllm_callback_wrapper.join_vectorlm_thread() diff --git a/profiling/__init__.py b/profiling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vectorlm/sampling/__init__.py b/vectorlm/sampling/__init__.py new file mode 100644 index 0000000..ebf9f1c --- /dev/null +++ b/vectorlm/sampling/__init__.py @@ -0,0 +1,10 @@ +from .abstract import AbstractSamplingEngine +from .sampling_lora import LoRASamplingEngine +from .utils import ( + ManagedLLM, + ManagedMultiProcGPUExecutor, + SamplingEngineProvider, + SynchronizationBarriers, + handle_sample, + multiprocess_wrap, +) diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py new file mode 100644 index 0000000..0747722 --- /dev/null +++ b/vectorlm/sampling/abstract.py @@ -0,0 +1,71 @@ +"""Wrapper around vLLM. 
Also handles synchronization.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import vllm + +if TYPE_CHECKING: + import torch + + from .utils import SynchronizationBarriers + + +class AbstractSamplingEngine(ABC): + """Interface for the sampling engine.""" + + def __init__( + self, + vllm_llm: vllm.LLM | None = None, + sampling_params: vllm.SamplingParams | None = None, + synchronization_barriers: SynchronizationBarriers | None = None, + ) -> None: + """Initialize sampling engine. + + Params: + vllm_llm: Instance of vllm.LLM, required only for rank 0. + sampling_params: Optionally, specify default sampling params. + synchronization_barriers: Optionally, supply barriers to + prevent workers from accessing GPU while vLLM is running. + + """ + self.vllm_llm = vllm_llm + self.sampling_params = sampling_params + self.synchronization_barriers = synchronization_barriers + self.vllm_train_step = -1 + + @abstractmethod + def update(self, model: torch.nn.Module, train_step: int) -> None: + """Update model in sampling engine if the current copy is stale. + + Params: + model: PeftModel, up-to-date model + train_step: int, train step of the given model. + """ + if self.vllm_train_step != train_step: + # Update parameters of self.vllm_llm using the given `model``. + return + + @abstractmethod + def generate( + self, + prompts: list[str], + sampling_params: vllm.SamplingParams | None = None, + ) -> list[vllm.RequestOutput]: + """Generate continuation for the given prompts synchronously. + + Invoke at all ranks. Output will be broadcasted to all ranks. + + Params: + ------ + prompts: List of input prompts. + sampling_params: Optionally, override self.sampling_params in + this request only. + + Returns + ------- + Output from vllm: list[vllm.RequestOutput], one for each prompt. + + """ diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py new file mode 100644 index 0000000..32dfbcc --- /dev/null +++ b/vectorlm/sampling/sampling_lora.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import torch.distributed as dist +import vllm +from vllm.lora.request import LoRARequest + +from vectorlm.utils.save_utils import save_peft_adapter + +from .abstract import AbstractSamplingEngine +from .utils import SynchronizationBarriers, multiprocess_wrap + +if TYPE_CHECKING: + from peft.peft_model import PeftModel + + +class LoRASamplingEngine(AbstractSamplingEngine): + """Sampling engine optimized for LoRA PEFT.""" + + def __init__( + self, + vllm_llm: vllm.LLM | None = None, + sampling_params: vllm.SamplingParams | None = None, + synchronization_barriers: SynchronizationBarriers | None = None, + adapter_temp_folder: str | None = None, + ) -> None: + """Initialize sampling engine. + + Params: + vllm_llm: Instance of vllm.LLM, required only for rank 0. + sampling_params: Optionally, specify default sampling params. + adapter_temp_folder: Temporary path where temporary adapter weights + are saved. If not specified, f`/dev/shm/{job_id}` + """ + assert synchronization_barriers is not None + self.barriers = synchronization_barriers + self.sampling_params = sampling_params + + if adapter_temp_folder is not None: + self.adapter_temp_folder = adapter_temp_folder + else: + slurm_job_id_or_placeholder = os.environ.get("SLURM_JOB_ID", "0") + + # Manually specify the in-memory /dev/shm filesystem + # to avoid disk wear and overhead. 
+ self.adapter_base_folder = "/dev/shm/" # noqa: S108 + self.adapter_temp_folder = os.path.join( + self.adapter_base_folder, + slurm_job_id_or_placeholder, + ) + + if dist.get_rank() == 0: + assert vllm_llm is not None + self.vllm_llm = vllm_llm + generate_fn_raw = vllm_llm.generate + else: + # placeholder, as the wrapped_fn won't be invoked outside rank-0. + generate_fn_raw = None + + self.generate_fn = multiprocess_wrap(generate_fn_raw, self.barriers) + self.vllm_train_step = -1 + + def update(self, model: PeftModel, train_step: int) -> None: + """Update model in sampling engine if the current copy is stale. + + Params: + model: PeftModel, up-to-date model + train_step: int, train step of the given model. + """ + self.barriers.before_generation.wait() + if self.vllm_train_step != train_step: + save_peft_adapter(model, self.adapter_temp_folder) + self.vllm_train_step = train_step + self.lora_request = LoRARequest( + "_vectorlm", + self.vllm_train_step + 1, + self.adapter_temp_folder, + ) + + self.barriers.after_generation.wait() + + def generate( + self, + prompts: list[str], + sampling_params: vllm.SamplingParams | None = None, + ) -> list[vllm.RequestOutput]: + """Generate continuation for the given prompts. Invoke at all ranks. + + Output will be broadcasted to all ranks. + + Params: + ------ + prompts: List of input prompts. + sampling_params: Optionally, override self.sampling_params in + this request only. + + Returns + ------- + Output from vllm: list[vllm.RequestOutput], one for each prompt. + + """ + return_value = self.generate_fn( + prompts, + sampling_params, + lora_request=self.lora_request, + use_tqdm=False, + ) + assert len(return_value) == len(prompts) + + return return_value diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py new file mode 100644 index 0000000..df233c3 --- /dev/null +++ b/vectorlm/sampling/utils.py @@ -0,0 +1,339 @@ +"""Generic utils for the sampling engines.""" + +from __future__ import annotations + +import json +import os +import threading +import time +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar + +from vllm import LLM, LLMEngine, SamplingParams +from vllm.engine.arg_utils import EngineConfig +from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor +from vllm.utils import Counter + +from .abstract import AbstractSamplingEngine + +if TYPE_CHECKING: + from threading import Barrier + + from vllm.worker.worker_base import WorkerBase + + +VECTORLM_WORKER_INIT_RDZV_TIMEOUT = 7 + + +class SampleOutput(NamedTuple): + """Represents possible responses to a prompt. + + Params: + prompt: prompt string. + options: list of proposed responses to this prompt. + """ + + prompt: str + options: list[str] + time_taken: float + + +class SynchronizationBarriers(NamedTuple): + """Barriers for limiting GPU access concurrency. + + Params: + vllm_init: Barrier to Ensures that vLLM engine is fully initialized + before running any vectorlm logic. + + before_generation: Ensure all processes have reached this statement, + or vectorlm in some processes might still be accessing the + accelerator when rank 0 invokes vLLM. + + after_generation: Detain all processes until after rank 0 is sure that + there are no outstanding vLLM jobs. 
+ """ + + vllm_init: Barrier + before_generation: Barrier + after_generation: Barrier + + +Fn = TypeVar("Fn", bound=Callable[..., Any]) + + +def multiprocess_wrap(fn: Fn | None, barriers: SynchronizationBarriers) -> Fn: + """Apply barrier to function and broadcast output. + + This wrapper function tries to preserve the type signature + of the wrapped function for the IDE. Tested for Pylance. + + While fn would be invoked only on rank 0, the wrapped function + should be invoked in the vectorlm thread at all ranks, so that + the barriers would block these threads from accessing GPU while + the fn is running. However, only the fn specified at rank 0 + will actually be executed. The fn parameter can be None at all + other ranks. + + Each rank would receive the same value as output. + + Params: + ------- + fn: Function to wrap. Output needs to be compatible with pickle. + This arg is required only on rank 0. + barriers: SynchronizationBarriers, only the before_generation and + after_generation barriers are required.. + + Returns + ------- + same output as Fn, but broadcasted to all ranks + (i.e., same value at all ranks) + + """ + + def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 + barriers.after_generation.wait() + + import torch.distributed # type: ignore[reportMissingImports] + + rank = torch.distributed.get_rank() + + # placeholder for output value, + # populate on rank 0 and then broadcast. + # torch requires placeholder element in object list. + output = [None] + if rank == 0: + assert fn is not None + output = [fn(*args, **kwargs)] + + # fn might access torch.dist, which might conflict with + # broadcast_object_list. Hence, keep all ranks witing until fn returns + # on rank 0. + barriers.before_generation.wait() + + torch.distributed.broadcast_object_list(output) + return output[0] + + return _wrapped_fn # type: ignore[reportReturnType] + + +class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): + """MultiProcGPUExecutor, but with VectorLM launched alongside vLLM. + + This class is compatible as an "executor_class" for the vLLM Engine. + + NCCL requires exactly one process for each GPU, so the vLLM and VectorLM + logic on each GPU need to fit into the same process. + + This class ensures that in each of these one-per-GPU processes, + VectorLM logic would run in a separate thread alongside the vLLM Worker. + """ + + # only missing parameter in vectorlm_fn is local_rank. + vectorlm_fn: Callable[[], None] + + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002,ANN003 + """Copy class variable vectorlm_fn into this instance. + + Doing so ensures that spawned sub-processes also have access + to vectorlm_fn, which might not otherwise be accessible as a class + variable. + """ + self.vectorlm_fn = ManagedMultiProcGPUExecutor.vectorlm_fn + super().__init__(*args, **kwargs) + + def _create_worker( + self, + local_rank: int = 0, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> WorkerBase: + """Launch vectorlm thread and vLLM worker in the same process. + + For rank 0, this method is invoked "blocking" inside the rank-0 process. + + For rank != 0, this method is supposed to be invoked in a child process + spawned from the main rank-0 process. 
+ """ + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["RANK"] = str(local_rank) + vectorlm_thread = threading.Thread( + target=self.vectorlm_fn, + name=f"Rank{local_rank}/vectorlm", + ) + vectorlm_thread.start() + + worker = super()._create_worker(*args, **kwargs, local_rank=local_rank) + assert worker is not None + worker.vectorlm_thread = vectorlm_thread + + return worker + + +class ManagedLLM(LLM): + """vllm.entrypoints.LLM but using an externally-initialized LLM Engine.""" + + def __init__(self, llm_engine: LLMEngine) -> None: + """Instantiate LLM instance using externally-initialized LLM Engine.""" + self.llm_engine = llm_engine + self.request_counter = Counter() + + +class SamplingEngineProvider: + """Provide VectorLM workers access to the SamplingEngine via a callback. + + The vLLM VectorLM logic needs to be launched alongside the vLLM worker + This class provides the VectorLM logic + + vLLM engine is initialized only after the initialize_engine call. + """ + + def __init__( + self, + engine_config: EngineConfig, + barriers: SynchronizationBarriers, + sampling_engine_class: AbstractSamplingEngine.__class__, + vectorlm_main_fn: Callable[ + [Callable[[], AbstractSamplingEngine]], + None, + ], + ) -> None: + """Instantiate class without initializing wrapped vLLM engine.""" + self.llm_engine: LLMEngine | None = None + self.llm: LLM | None = None + self.engine_config = engine_config + self.barriers = barriers + self.sampling_engine_class = sampling_engine_class + + # Only missing args is local_rank. + self.vectorlm_fn: Callable[[], None] = partial( + vectorlm_main_fn, + self.get_sampling_engine, + ) + + def initialize_engine(self) -> None: + """Initialize vLLM engine. + + Invoke this method only from the rank 0 __main__. + + This method blocks until all vectorlm threads (including rank 0) + have also reached the vllm_init barrier. + """ + ManagedMultiProcGPUExecutor.vectorlm_fn = self.vectorlm_fn + + self.llm_engine = LLMEngine( + **self.engine_config.to_dict(), + executor_class=ManagedMultiProcGPUExecutor, + log_stats=False, + ) + + self.llm = ManagedLLM(self.llm_engine) + print(f"Instantiated ManagedLLM: {self.llm}") + + thread_name = threading.current_thread().getName() + print(f"{thread_name}: vllm_init waiting") + + try: + self.barriers.vllm_init.wait(VECTORLM_WORKER_INIT_RDZV_TIMEOUT) + except threading.BrokenBarrierError as e: + msg = ( + "SamplingEngineProvider requires get_sampling_engine() to be " + "invoked across all VectorLM ranks (including rank 0) prior " + "to any Torch NCCL logic. \n" + "If sampling engine is not required, " + "please avoid using SamplingEngineProvider, as this provider " + "would launch vLLM and might hang because of concurrent NCCL " + "access. Launch the training script via torchrun instead." + ) + raise RuntimeError(msg) from e + + print(f"{thread_name}: vllm_init cleared") + + def get_sampling_engine(self) -> AbstractSamplingEngine: + """Instantiate sampling engine. + + Invoke this callback method from the VectorLM thread of each process, + including rank 0. + + SamplingEngine handles synchronization and prevents concurrent + NCCL access. Hence, a SamplingEngine instance shall be instantiated + regardless of the rank of the process. + + This method blocks until the vLLM Engine is fully initialized. 
+ """ + thread_name = threading.current_thread().getName() + print(f"{thread_name}: vllm_init wait") + self.barriers.vllm_init.wait() + print(f"{thread_name}: vllm_init cleared") + + # vLLM is instantiated and required only for the rank 0 SamplingEngine. + assert (self.llm is not None) or (int(os.environ["LOCAL_RANK"]) != 0) + return self.sampling_engine_class( + self.llm, + SamplingParams(seed=0, temperature=0), + self.barriers, + ) + + def join_vectorlm_thread(self) -> None: + """Join the rank 0 (main process) vectorlm thread. + + Invoke this function only from __main__ (of rank 0) after + initialize_engine. + """ + assert self.llm_engine is not None + model_executor = self.llm_engine.model_executor + assert isinstance(model_executor, ManagedMultiProcGPUExecutor) + assert model_executor.driver_worker is not None + model_executor.driver_worker.vectorlm_thread.join() + + +def handle_sample( + sampling_engine: AbstractSamplingEngine, + prompts: Iterable[str], + output_path: str | None, + sampling_params: SamplingParams | None = None, + extra_data: dict[str, Any] | None = None, +) -> list[SampleOutput]: + """Sample continuations and optionally save to disk. + + Params: + ------ + sampling_engine: an instantiation of sampling engine. + prompts: a list (iterable) of prompts. + output_path: if provided, append output json lines to this file. + Recommended: specify output_path only on rank 0. + sampling_params: forwarded to sampling engine. + extra_data: prepended to each line of output (e.g., current epoch.) + + Returns + ------- + List of SampleOutput, representing continuations for each prompt. + + """ + _prompts = list(prompts) + + start_time = time.time() + generation_output = sampling_engine.generate(_prompts, sampling_params) + time_taken = time.time() - start_time + + # Parse sample engine output and keep only the output strings. + sample_outputs: list[SampleOutput] = [] + for prompt, request_output in zip(prompts, generation_output): + sample_outputs.append( + SampleOutput( + prompt, + [option.text for option in request_output.outputs], + time_taken, + ), + ) + + # note: always produce jsonl_output_lines to ensure code coverage. + extra_data = extra_data if extra_data is not None else {} + jsonl_output_lines: list[str] = [ + json.dumps({**extra_data, **sample_output._asdict()}) + for sample_output in sample_outputs + ] + if output_path is not None: + with open(output_path, "a") as output_jsonl_file: + output_jsonl_file.write("\n".join(jsonl_output_lines) + "\n\n") + + return sample_outputs diff --git a/vectorlm/tests/test_vllm.py b/vectorlm/tests/test_vllm.py new file mode 100644 index 0000000..8f61ed1 --- /dev/null +++ b/vectorlm/tests/test_vllm.py @@ -0,0 +1,310 @@ +"""vLLM Integration tests. + +LLM for scaffolding: +- gemma-2b (vLLM supports LoRA for Gemma but not OPT) + +Scaffolding and fixtures: +- vLLM + - Spin up vLLM Engine via Python API + - LoRA request via peft LoRA adapter from + - regular disk. + - folder saved in /dev/shm. 
+- References top-k log probabilities via vLLM + - Base model + - LoRA model loaded from /dev/shm + - LoRA model loaded from disk +""" + +from __future__ import annotations + +import os.path +import shutil +from typing import Generator + +import numpy as np +import pytest +import vllm +import vllm.sequence +from huggingface_hub import snapshot_download +from vllm.lora.request import LoRARequest + +BASE_MODEL_PATH = "/model-weights/gemma-2b" +LORA_ADAPTER_HF_HUB_REPO = "jacobthebanana/example-gemma-2b-lora-gsm8k" +LORA_ADAPTER_LOCAL_FOLDER = "data/example_lora_adapter" +NUM_TOP_LOGPROBS = 5 + + +@pytest.fixture(scope="session") +def lora_adapter_path() -> str: + """Download example LoRA adapters from HuggingFace hub. + + Returns + ------- + Path to the adapters on local filesystem. + + """ + if not os.path.exists(f"{LORA_ADAPTER_HF_HUB_REPO}"): + snapshot_download( + LORA_ADAPTER_HF_HUB_REPO, + local_dir=LORA_ADAPTER_LOCAL_FOLDER, + ) + + return LORA_ADAPTER_LOCAL_FOLDER + + +@pytest.fixture(scope="session") +def lora_adapter_path_dev_shm( + lora_adapter_path: str, +) -> Generator[str, None, None]: + """Create a copy of LoRA adapters within /dev/shm. + + Returns + ------- + Path to adapters in the /dev/shm filesystem. + + """ + # Specifically require /dev/shm since /tmp might go to an actual disk, + # incurring overhead and unnecessary SSD wear. + lora_adapter_dev_shm_path = f"/dev/shm/{LORA_ADAPTER_HF_HUB_REPO}" # noqa: S108 + os.makedirs(lora_adapter_dev_shm_path, exist_ok=True) + shutil.copytree( + lora_adapter_path, + lora_adapter_dev_shm_path, + dirs_exist_ok=True, + ) + print(f"Copy: {lora_adapter_path}, {lora_adapter_dev_shm_path}") + + yield lora_adapter_dev_shm_path + + # Clean up to free memory. + shutil.rmtree(lora_adapter_dev_shm_path) + + +@pytest.fixture(scope="session") +def vllm_model() -> vllm.LLM: + """Spin up vLLM base model.""" + return vllm.LLM( + BASE_MODEL_PATH, + gpu_memory_utilization=0.3, + enable_lora=True, + ) + + +@pytest.fixture(scope="session") +def vllm_sampling_params() -> vllm.SamplingParams: + """Return example vLLM sampling parameters for consistency across tests.""" + return vllm.SamplingParams( + logprobs=NUM_TOP_LOGPROBS, + temperature=0.5, + seed=1, + ) + + +@pytest.fixture(scope="session") +def example_prompts() -> list[str]: + """Return example prompts.""" + return [ + "Vector Institute is located in", + "The answer to life the universe and everything is ", + "Vector Institute is located in", + ] + + +def extract_logprobs( + vllm_responses: list[vllm.RequestOutput], +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Extract logprobs from vllm response. + + Additionally, ensures none of these output is None. + + Params + ------ + vllm_responses: output from LLM.generate() + + Returns + ------- + Nested list, one list of logprobs instance for each prompt. 
+ + """ + logprobs_responses: list[list[vllm.sequence.SampleLogprobs]] = [] + for response in vllm_responses: + for output in response.outputs: + logprobs_options: list[vllm.sequence.SampleLogprobs] = [] + logprobs = output.logprobs + assert logprobs is not None + logprobs_options.append(logprobs) + + logprobs_responses.append(logprobs_options) + + return logprobs_responses + + +def assert_logprobs_allclose( + logprobs_a: vllm.sequence.SampleLogprobs, + logprobs_b: vllm.sequence.SampleLogprobs, +) -> None: + """Ensure that logprobs_a are all close with logprobs_b.""" + assert len(logprobs_a) == len(logprobs_b) + for token_logprobs_a, token_logprobs_b in zip(logprobs_a, logprobs_b): + assert token_logprobs_a.keys() == token_logprobs_b.keys() + token_logprobs_a_array = np.asarray( + [token_logprobs_a[k].logprob for k in token_logprobs_a], + ) + token_logprobs_b_array = np.asarray( + [token_logprobs_b[k].logprob for k in token_logprobs_a], + ) + assert np.allclose(token_logprobs_a_array, token_logprobs_b_array) + + +@pytest.fixture(scope="session") +def base_llm_logprobs( + vllm_model: vllm.LLM, + example_prompts: list[str], + vllm_sampling_params: vllm.SamplingParams, +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs for base LLM (no LoRA adapter).""" + vllm_responses = vllm_model.generate(example_prompts, vllm_sampling_params) + return extract_logprobs(vllm_responses) + + +def get_lora_llm_logprobs( + vllm_model: vllm.LLM, + example_prompts: list[str], + vllm_sampling_params: vllm.SamplingParams, + _lora_adapter_path_fixture_name: str, + request: pytest.FixtureRequest, +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs for LoRA-adapted LLM.""" + lora_adapter_path = request.getfixturevalue(_lora_adapter_path_fixture_name) + lora_request = LoRARequest("example_adapter", 1, lora_adapter_path) + vllm_responses = vllm_model.generate( + example_prompts, + vllm_sampling_params, + lora_request=lora_request, + ) + return extract_logprobs(vllm_responses) + + +@pytest.fixture(scope="session") +def lora_llm_logprobs_local_and_dev_shm( + vllm_model: vllm.LLM, + example_prompts: list[str], + vllm_sampling_params: vllm.SamplingParams, + request: pytest.FixtureRequest, +) -> tuple[list[list[vllm.sequence.SampleLogprobs]], ...]: + """Return logprobs via LoRA adapter loaded locally and from /dev/shm. + + Returns + ------- + Two list of lists (options) of vLLM logprobs. + local_adapter_logprobs, dev_shm_adapter_logprobs + + """ + return tuple( + get_lora_llm_logprobs( + vllm_model, + example_prompts, + vllm_sampling_params, + _adapter_path, + request, + ) + for _adapter_path in [ + "lora_adapter_path", + "lora_adapter_path_dev_shm", + ] + ) + + +@pytest.fixture(scope="session") +def lora_llm_logprobs_local( + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs from locally-loaded LoRA adapters.""" + return lora_llm_logprobs_local_and_dev_shm[0] + + +@pytest.fixture(scope="session") +def lora_llm_logprobs_dev_shm( + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> list[list[vllm.sequence.SampleLogprobs]]: + """Return logprobs from LoRA adapters loaded via /dev/shm ram-disk.""" + return lora_llm_logprobs_local_and_dev_shm[1] + + +# Reuse this test case definition across base and LoRA logprobs. 
+@pytest.mark.parametrize( + "logprobs_fixture_name", + [ + "base_llm_logprobs", + "lora_llm_logprobs_local", + "lora_llm_logprobs_dev_shm", + ], +) +def test_logprobs_consistency( + logprobs_fixture_name: str, + request: pytest.FixtureRequest, +) -> None: + """Verify consistency of logprobs. + + Since vLLM seed is fixed, the same prompt should produce + the same logprobs. + """ + logprobs: list[list[vllm.sequence.SampleLogprobs]] = ( + request.getfixturevalue(logprobs_fixture_name) + ) + + assert_logprobs_allclose(logprobs[0][0], logprobs[2][0]) + + with pytest.raises(AssertionError): + assert_logprobs_allclose(logprobs[2][0], logprobs[1][0]) + + +def test_compare_ref_logprobs( + base_llm_logprobs: list[list[vllm.sequence.SampleLogprobs]], + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> None: + """Ensure base_llm_logprobs are different from lora_llm_logprobs.""" + # Test both lora_adapter options: disk and ram-disk + for lora_llm_logprobs in lora_llm_logprobs_local_and_dev_shm: + for base_llm_seq_logprobs, lora_llm_seq_logprobs in zip( + base_llm_logprobs, + lora_llm_logprobs, + ): + with pytest.raises(AssertionError): + assert_logprobs_allclose( + base_llm_seq_logprobs[0], + lora_llm_seq_logprobs[0], + ) + + +def test_compare_lora_logprobs( + lora_llm_logprobs_local_and_dev_shm: tuple[ + list[list[vllm.sequence.SampleLogprobs]], + ..., + ], +) -> None: + """Ensure LoRA logprobs from local and ram-disk adapters match.""" + for logprobs_local_seq, logprobs_dev_shm_seq in zip( + *lora_llm_logprobs_local_and_dev_shm, + ): + # Each of these represents the logprobs options for a sequence output. + logprobs_local_seq: list[vllm.sequence.SampleLogprobs] + logprobs_dev_shm_seq: list[vllm.sequence.SampleLogprobs] + + assert_logprobs_allclose(logprobs_local_seq[0], logprobs_dev_shm_seq[0]) + sequence_tokens = "".join( + [ + str(next(iter(token.values())).decoded_token) + for token in logprobs_local_seq[0] + ], + ) + print(f"\nVerified equivalence: {sequence_tokens}") diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index fc52c36..c571bf7 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -2,16 +2,16 @@ import math import os -from typing import Any +from typing import TYPE_CHECKING, Any import peft import torch import torch.distributed as dist +import wandb from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from transformers import PreTrainedTokenizer -import wandb from vectorlm.dataset import Dataset from vectorlm.utils.data_utils import Config from vectorlm.utils.save_utils import ( @@ -26,6 +26,9 @@ save_scheduler, ) +if TYPE_CHECKING: + from vectorlm.sampling import AbstractSamplingEngine + class Trainer: """Main trainer class. 
@@ -85,13 +88,17 @@ def __init__( self.max_steps = None self.saving_steps = None self._post_process(original_dataset_length) - self.peft_method: str | None = None self.is_peft_adapter_restored: bool = False if "lora_peft_config" in self.config: self.peft_method = peft.utils.peft_types.PeftType.LORA + self.peft_method: str | None = None + self.is_peft_adapter_restored: bool = False + + self.sampling_engine: AbstractSamplingEngine | None = None + def _post_process(self, ds_orig_length: int) -> None: """Calculate steps for weight updates and saving.""" sharded_ds_orig_len = math.ceil( @@ -115,6 +122,7 @@ def prepare_trainer( dataset: Dataset, optimizer: Optimizer, lr_scheduler: LRScheduler | ReduceLROnPlateau, + sampling_engine: AbstractSamplingEngine | None = None, is_peft_adapter_restored: bool = False, ) -> None: """Set all essential training requirements. @@ -127,6 +135,9 @@ def prepare_trainer( optimizer: The training optimizer. lr_scheduler: The LR scheduler. + sampling_engine: Optionally, provide a sampling engine to enable + sampling during training. + is_peft_adapter_restored: whether peft is enabled and adapters were restored from filesystem. @@ -137,6 +148,8 @@ def prepare_trainer( self.optimizer = optimizer self.lr_scheduler = lr_scheduler + self.sampling_engine = sampling_engine + self.is_peft_adapter_restored = is_peft_adapter_restored def save_checkpoint(self, epoch: int) -> None: @@ -233,7 +246,7 @@ def find_checkpoint(self, checkpoint_dir: str) -> int: """ checkpoint = checkpoint_exists(checkpoint_dir) - if checkpoint: + if (checkpoint) and (self.config.checkpointing_enabled): main_ckpt_dir = os.path.join(checkpoint_dir, "checkpoints") latest_ckpt_dir = get_latest_checkpoint_dir(main_ckpt_dir) full_ckpt_dir = os.path.join(main_ckpt_dir, latest_ckpt_dir) @@ -270,6 +283,24 @@ def step( test_loss = None if self.tr_step % self.logging_steps == 0: test_loss = self.eval_step(epoch) + + if (self.sampling_engine is not None) and ( + self.tr_step % self.config.sampler.sample_frequency == 0 + ): + from vectorlm.sampling import handle_sample + + self.sampling_engine.update(self.model, self.tr_step) + handle_sample( + self.sampling_engine, + self.config.sampler.prompts, + output_path=( + self.config.sampler.output_jsonl_path + if dist.get_rank() == 0 + else None + ), + extra_data={"train_step": self.tr_step}, + ) + self.tr_step += 1 return train_loss, test_loss diff --git a/vectorlm/utils/misc_utils.py b/vectorlm/utils/misc_utils.py index 1c59b98..081deb2 100644 --- a/vectorlm/utils/misc_utils.py +++ b/vectorlm/utils/misc_utils.py @@ -4,8 +4,8 @@ from typing import Any import torch.distributed as dist - import wandb + from vectorlm.utils.data_utils import Config diff --git a/vectorlm/utils/save_utils.py b/vectorlm/utils/save_utils.py index cd228dc..6c9d501 100644 --- a/vectorlm/utils/save_utils.py +++ b/vectorlm/utils/save_utils.py @@ -173,8 +173,10 @@ def save_peft_adapter( StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): - if dist.get_rank() == 0: - model.save_pretrained(output_path) + model.save_pretrained( + output_path, + is_main_process=(dist.get_rank() == 0), + ) def save_model_and_optimizer(