Implement vLLM FSDP LoRA hot-swapping integration #10

Open: wants to merge 92 commits into base: master

Commits (92)
904d1e1
Implemented baseline LoRA peft for one Nvidia GPU.
jacobthebanana Feb 26, 2024
2ace67e
Added support for saving lora adapters.
jacobthebanana Feb 27, 2024
a25e667
save_utils: added support for non-FSDP optimizers.
jacobthebanana Feb 29, 2024
65a2dbf
example_lora: highlighted current lora (non-fsdp) limitations.
jacobthebanana Feb 29, 2024
ed4c84f
Added instructions on LoRA on one GPU.
jacobthebanana Feb 29, 2024
5a72392
Added example script for launching lora.
jacobthebanana Feb 29, 2024
e176ac8
Revised instructions on LoRA on one GPU.
jacobthebanana Feb 29, 2024
2d869b0
Implemented LoRA FSDP.
jacobthebanana Mar 6, 2024
dc098d6
Reverted automatic formatter changes in README.md
jacobthebanana Mar 6, 2024
5a1fd76
Eliminated non-FSDP logic from save_utils.
jacobthebanana Mar 6, 2024
7e187bc
Moved lora config out of example config.yaml.
jacobthebanana Mar 6, 2024
3eea331
Implemented LoRA benchmarking logic for worker.
jacobthebanana Mar 11, 2024
906e4f3
model_utils: Refactored get_lora_model to reduce interface width. (th…
jacobthebanana Mar 11, 2024
0c41535
test_modelling: moved text output to data/.
jacobthebanana Mar 11, 2024
f24d2fa
added example yaml config for lora benchmarking.
jacobthebanana Mar 11, 2024
7d27d90
launch_benchmark: marked qos flag as optional.
jacobthebanana Mar 11, 2024
d22ea85
launch_benchmark: added option to limit number of jobs launched.
jacobthebanana Mar 11, 2024
84b953a
launch_benchmark: implemented torch profiler integration.
jacobthebanana Mar 11, 2024
e1cda07
Merged changes from low CPU memory usage feature (#6) into jjt/lora-b…
adil-a Mar 11, 2024
48f61d9
Revised launch_benchmark.py to use new profiling path.
jacobthebanana Mar 11, 2024
9876ebe
Enabled automatic creation of data/trace folder.
jacobthebanana Mar 11, 2024
5330871
Added instructions for profiling tools.
jacobthebanana Mar 11, 2024
17e24bd
Merge remote-tracking branch 'origin/master' into jjt/lora-baseline
jacobthebanana Mar 11, 2024
9982791
Cleaned up duplicate imports from merge.
jacobthebanana Mar 11, 2024
9a76e80
Cleaned up duplicate imports from merge.
jacobthebanana Mar 11, 2024
ffa7067
Cleaned up parse_benchmark.py
jacobthebanana Mar 11, 2024
bd893e1
Integrated LoRA logic into llama_example.py.
jacobthebanana Mar 11, 2024
c2f346f
Moved lora_configs into train_parameters in config yaml. Adjusted doc…
jacobthebanana Mar 11, 2024
56cb750
Revised handling of nproc-per-node in benchmark script.
jacobthebanana Mar 12, 2024
97ddd8c
Included parameter_count info in benchmark output.
jacobthebanana Mar 12, 2024
7c7a000
Implemented basic util for parsing benchmarking output.
jacobthebanana Mar 12, 2024
f33e89a
model_utils: Enabled low_cpu_mem_usage in auto model from_pretrained…
jacobthebanana Mar 12, 2024
35bdbcd
launch_lora_benchmark.sh: implemented automatic identification of num…
jacobthebanana Mar 13, 2024
e6b2e59
requirements.txt: included accelerate to support low_cpu_mem loading.
jacobthebanana Mar 13, 2024
db148fa
benchmark.py: adjusted BenchmarkingDataset to avoid StopIteration exc…
jacobthebanana Mar 13, 2024
35f6c5d
benchmark.py: added env var flag to toggle export_trace
jacobthebanana Mar 15, 2024
4a1251b
parse_benchmark: included profiler table in output file.
jacobthebanana Mar 15, 2024
79fd79b
get_lora_model_from_base_model: enabled peft for models loaded via lo…
jacobthebanana Mar 15, 2024
5c25397
model_utils: revised dtype handling for peft-wrapped models.
jacobthebanana Mar 15, 2024
c19de82
parse_benchmark: implemented sorting of profiler table output.
jacobthebanana Mar 15, 2024
7e13cde
Merged example_lora into examples/llama_example.pu
jacobthebanana Mar 15, 2024
28d4ede
Added instructions related to parse_benchmark
jacobthebanana Mar 15, 2024
a863ed2
parse_benchmark: implemented aggregation across repeated metrics.
jacobthebanana Mar 15, 2024
eb3721a
Implemented non-LoRA profiling and benchmarking.
jacobthebanana Apr 9, 2024
37f5dec
Various static typechecking and formatting fixes.
jacobthebanana Apr 11, 2024
78c6faf
Implemented restoring LoRA train state from filesystem.
jacobthebanana Apr 15, 2024
aea2ed8
Included train step number in LoRA adapter output path.
jacobthebanana Apr 15, 2024
dad6553
Added reference throughput table to documentation.
jacobthebanana Apr 16, 2024
bbcda75
Added unit description to reference throughput table.
jacobthebanana Apr 16, 2024
d397488
Added unit description to reference throughput table.
jacobthebanana Apr 16, 2024
35b97b8
Benchmark: added option to override max_length of pre-trained model.
jacobthebanana Apr 16, 2024
6af7791
Deleted unused `accelerate` dependency from requirements.txt
jacobthebanana Apr 16, 2024
97be477
Benchmark: added comment on max_length.
jacobthebanana Apr 16, 2024
b43e565
Benchmark: added comment on batch size.
jacobthebanana Apr 16, 2024
607de70
Benchmark: added option to override batch size.
jacobthebanana Apr 16, 2024
bdef48f
Benchmark throughput documentation: revised word choices.
jacobthebanana Apr 16, 2024
3294a39
LoRA Hot-Swap: Implemented vLLM integration test scaffolding and PyTe…
jacobthebanana Apr 16, 2024
2bb7bad
LoRA Hot-Swap: Implemented vLLM LoRA hot-swap integration proof-of-co…
jacobthebanana Apr 16, 2024
5d93afe
LoRA Hot-Swap: added additional fixtures to enhance readability.
jacobthebanana Apr 16, 2024
02988a5
LoRA Hot-Swap: Deleted redundant np.asarray call in integration test …
jacobthebanana Apr 17, 2024
5ad5d90
LoRA Hot-Swap: Updated test case documentations to reflect code reuse…
jacobthebanana Apr 17, 2024
afb321c
Moved profiling-tracking logic out of Trainer.
jacobthebanana Apr 17, 2024
5babf6b
Eliminated hasattr check related to no_sync since FSDP is always enab…
jacobthebanana Apr 17, 2024
c1b31c4
Replaced peft fsdp_auto_wrap_policy to eliminate implicit `accelerate…
jacobthebanana Apr 17, 2024
f0b201c
Configured LoRA auto-wrap policy as off by default- enable the policy…
jacobthebanana Apr 17, 2024
429ec5e
Revised punctuation in lora_requires_grad_policy_fn.
jacobthebanana Apr 17, 2024
afbc061
Renamed declarative `enable_lora` with descriptive `is_lora_enabled`.
jacobthebanana Apr 17, 2024
7bc6f89
Merge commit 'afbc061' from jjt/lora-baseline into jjt/lora-vllm-hotswap
jacobthebanana Apr 19, 2024
4936b1d
Added (request for comment) AbstractInferenceEngine interface and LoR…
jacobthebanana Apr 22, 2024
aa1fe8b
Renamed "inference" to "sampling".
jacobthebanana Apr 22, 2024
675367b
Added reference sampling steps to llama_example. Added example sampli…
jacobthebanana Apr 25, 2024
ca2cad8
Added train_parameters.get("sampler").
jacobthebanana Apr 25, 2024
649a4b8
[WIP] Implemented vLLM wrapper combining vectorlm and vLLM workers.
jacobthebanana May 6, 2024
ebb7bc9
vllm integration: Eliminated duplicate vllm ResultHandler.
jacobthebanana May 6, 2024
1f1f88e
vllm integration [WIP]: Revised vectorlm-vllm concurrency handling.
jacobthebanana May 6, 2024
11a1ba5
vllm integration [WIP]: Implemented inference during training.
jacobthebanana May 6, 2024
b697dc0
vllm integration [WIP]: Implemented lora hotswap.
jacobthebanana May 7, 2024
112ea3c
vllm integration [WIP]: Moved sampler-related logic into Trainer.
jacobthebanana May 9, 2024
07405dc
Merge remote-tracking branch 'origin/master' into jjt/lora-vllm-hotswap
jacobthebanana May 9, 2024
e707987
vllm integration: Added documentation on sampling engine.
jacobthebanana May 10, 2024
61c39ad
vllm integration: Added documentation on sampling engine.
jacobthebanana May 10, 2024
609c023
[WIP] vllm hotswapping: Implement minimum-viable wrapper for vllm/main.
jacobthebanana May 23, 2024
9585c01
[WIP] vllm hotswapping: Reduced area of vLLM integration interface.
jacobthebanana May 23, 2024
31464aa
vllm hotswapping [WIP]: Reduced area of vLLM integration interface.
jacobthebanana May 23, 2024
059d57f
vllm hotswapping [WIP]: Refactored vLLM integration interface to mini…
jacobthebanana May 24, 2024
b5c6389
vllm hotswapping [WIP]: deleted unneded torch dist.barrier from llama…
jacobthebanana May 24, 2024
f506812
vllm hotswapping [WIP]: documentation fixes and cleanup.
jacobthebanana May 24, 2024
3e27e84
vllm hotswapping [WIP]: cleaned up documentation related to multiproc…
jacobthebanana May 24, 2024
879399f
vllm hotswapping [WIP]: cleaned up changes in llama_example.py.
jacobthebanana May 24, 2024
bc0ae52
vllm hotswapping [WIP]: added example gemma sampling config.
jacobthebanana May 24, 2024
5e8944d
vllm hotswapping: Refactoring and cleanup.
jacobthebanana Jun 18, 2024
2005a7d
vllm hotswapping: Moved Sampler import into conditional block to avoi…
jacobthebanana Jun 18, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -9,4 +9,5 @@ data/
**/*.pyc
/.cache
/.vscode
/data
/data
/env
14 changes: 12 additions & 2 deletions configs/config.yaml
@@ -21,15 +21,14 @@ train_parameters:
low_cpu_mem_usage: True

# LoRA config: uncomment the block below to enable LoRA

# lora_peft_config:
# task_type: CAUSAL_LM
# inference_mode: False
# r: 8
# lora_alpha: 32
# lora_dropout: 0.1


# Gradient norm clipping
max_grad_norm: 1
gradient_accumulation_steps: 4
@@ -50,6 +49,17 @@ train_parameters:
logging_steps: 500
save_frequency: 0.25

# Sampling during training
# Uncomment the block below to enable.

# sampler:
# sample_frequency: 8
# output_jsonl_path: data/output.jsonl
# prompts:
# - "Vector Institute of the"
# - "Vector Institute is located in the city of"
# - "The answer to life the universe and everything is"

dataset:
ignore_index: -100
eval_bs: 8
16 changes: 16 additions & 0 deletions docs/config.md
@@ -51,6 +51,22 @@ Similar to the wandb config above, these keyword parameters are fed directly int
* `logging_steps`: How often evaluation is run using the evaluation dataset.
* `save_frequency`: The frequency at which checkpointing occurs. This must be between 0 and 1.


### Sampling during Training

To disable sampling during training, comment out the entire `sampler` section.

* `sample_frequency`: Number of train steps between two consecutive sampling steps.
* `output_jsonl_path`: Optional; write sampled output to the specified jsonl file.
* `prompts`: YAML list of prompt strings.

Each line of the output jsonl file is a JSON dictionary with the following keys (a minimal reading sketch follows this list):

* `tr_step`: number (integer), trainer step when this line was generated.
* `prompt`: string.
* `options`: list of strings, one for each possible option that the sampler provided.
* `time_taken`: float, number of seconds taken to generate **all** prompts at this step.
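
A minimal sketch of reading this output back, assuming the hypothetical `data/output.jsonl` path used in the sample configuration:

```
import json

# Hypothetical path; matches the `output_jsonl_path` example in the sample config.
with open("data/output.jsonl") as f:
    for line in f:
        record = json.loads(line)
        # Keys documented above: tr_step, prompt, options, time_taken.
        print(record["tr_step"], record["prompt"])
        for option in record["options"]:
            print("  ", option)
        print("seconds for all prompts at this step:", record["time_taken"])
```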

## Dataset

* `ignore_index`: The integer index used to ignore a given token in the loss calculation. Cross-entropy loss by default uses `-100`.
26 changes: 26 additions & 0 deletions docs/sampling.md
@@ -0,0 +1,26 @@
# Efficient Sampling during Training

Some training objectives, notably PPO, require sampling from the language model many times during training. The most straightforward approach is to invoke `model.generate` from within the training loop. However, alternative inference engines such as vLLM promise over 10x the sampling throughput, in tokens generated per second, when using a large sampling batch size. If `model.generate` is taking up too much of the training time, it may be worthwhile to offload sampling to one of these third-party engines.

The main challenge with these third-party engines is that most of them assume the language model weights are fixed, so there is no straightforward way to update those weights. Usually, updating the weights requires restarting the sampling engine, which can take minutes. At the same time, PPO and similar techniques rely heavily on replacing the weights efficiently: otherwise training is no longer on-policy, and convergence takes substantially more training steps. To resolve this issue, we implemented techniques to "hot-swap" the model parameters used in the sampling process.
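
For intuition only: vLLM itself supports attaching a different LoRA adapter per request without restarting the engine, which is the kind of building block such hot-swapping can lean on. The sketch below uses vLLM's public `LLM`, `SamplingParams`, and `LoRARequest` APIs; the model name and adapter path are placeholders, and this is not the exact mechanism implemented in this repository (which also coordinates FSDP-trained weights across ranks).

```
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Base model with LoRA support enabled; the model name is a placeholder.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
params = SamplingParams(max_tokens=64)

# After each checkpoint, point a new LoRARequest at the freshly saved
# adapter directory (hypothetical path) and sample with it.
adapter = LoRARequest("step_100_adapter", 1, "/checkpoints/lora/step_100")
outputs = llm.generate(["Vector Institute is"], params, lora_request=adapter)
print(outputs[0].outputs[0].text)
```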

Additionally, it is not straightforward to maintain consistently high GPU utilization when combining sampling with training.
This repository lets you make the most of all your GPUs by fitting vLLM and your training loop onto the same set of devices, so that no GPU sits idle: when a GPU is not running training, it is busy sampling with vLLM. These slides ([link](https://docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30/edit?usp=sharing)) provide an overview of the architecture behind this approach.

## Example: Supervised fine-tuning

We provide a basic example that samples from the language model while fine-tuning with a causal language modelling objective. To run the example, uncomment the `sampler` section in your configuration yaml, choose a port for `nccl` coordination, and run the following command (with `python3` directly, not through `torchrun`):

```
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=19132
python3 examples/llama_example_mp.py \
--yaml_path configs/config.yaml \
--world_size 2
```

## Bring your own training loop

While the reference implementation covers only supervised fine-tuning, we provide abstractions that make it easier to implement your own training loop, be it PPO RLHF, TWIST, or something else. The goal is to abstract away all of the synchronization logic, so that a training loop built on one GPU can scale to multiple GPUs on the same server with minimal modifications.

To get started, refer to examples/llama_example.py and vectorlm/trainer.py. Usually the vLLM engine is accessible only from rank 0, which makes synchronization challenging. When invoked through llama_example_mp, the `SamplingEngine` interface in VectorLM lets your training loop access `vLLM.LLM.generate` from all ranks and returns the same result on every rank. Because the synchronization barriers require all ranks to reach the synchronization point, you must invoke `generate` from all ranks.
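
As a rough sketch only (the `trainer.step` call and the `generate` signature are assumptions about the interfaces described above, not this repository's verbatim API), a custom loop that samples on every rank every few steps could look like this:

```
from __future__ import annotations

from collections.abc import Iterable


def train_with_sampling(
    trainer,            # vectorlm Trainer; `step` is an assumed API
    sampling_engine,    # sampling engine; `generate` as described above
    train_dataloader: Iterable,
    prompts: list[str],
    sample_frequency: int,
    rank: int,
) -> None:
    """Hypothetical custom loop that samples every `sample_frequency` steps.

    `generate` is invoked on every rank so that all ranks reach the
    synchronization barriers together; only rank 0 prints the results.
    """
    for step, batch in enumerate(train_dataloader):
        trainer.step(batch)  # assumed single-train-step API

        if step % sample_frequency == 0:
            outputs = sampling_engine.generate(prompts)
            if rank == 0:
                print(step, outputs)
```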
26 changes: 26 additions & 0 deletions examples/launch_lora.sh
@@ -0,0 +1,26 @@
#!/bin/bash
#SBATCH --job-name=llama7b-2
#SBATCH --nodes=1
#SBATCH --mem=0
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-gpu=6
#SBATCH --gres=gpu:4
#SBATCH --output=llama-2-7b.%j.out
#SBATCH --error=llama-2-7b.%j.err
#SBATCH --partition=a100
#SBATCH --qos=your_assigned_qos # CHANGE
#SBATCH --open-mode=append
#SBATCH --wait-all-nodes=1
#SBATCH --time=3-00

export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand; this flag disables NCCL's use of it.
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN

# export TORCH_DISTRIBUTED_DEBUG=DETAIL # Uncomment these flags for debugging communication
# export TORCH_CPP_LOG_LEVEL=INFO
export LOGLEVEL=INFO
export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=0

torchrun --nnodes=1 --nproc-per-node=${SLURM_GPUS_ON_NODE} example_lora.py --yaml_path configs/config-lora.yaml
26 changes: 26 additions & 0 deletions examples/launch_lora_one_gpu.sh
@@ -0,0 +1,26 @@
#!/bin/bash
#SBATCH --job-name=llama7b-2-lora
#SBATCH --nodes=1
#SBATCH --mem=32GB
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-gpu=6
#SBATCH --gres=gpu:1
#SBATCH --output=llama-2-7b-lora.%j.out
#SBATCH --error=llama-2-7b-lora.%j.err
#SBATCH --partition=a100
#SBATCH --qos=your_assigned_qos # CHANGE
#SBATCH --open-mode=append
#SBATCH --wait-all-nodes=1
#SBATCH --time=3-00

export NCCL_IB_DISABLE=1 # Our cluster does not have InfiniBand; this flag disables NCCL's use of it.
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN

# export TORCH_DISTRIBUTED_DEBUG=DETAIL # Uncomment these flags for debugging communication
# export TORCH_CPP_LOG_LEVEL=INFO
export LOGLEVEL=INFO
export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=0

torchrun --nnodes=1 --nproc-per-node=1 example_lora.py --yaml_path configs/config-lora.yaml
60 changes: 55 additions & 5 deletions examples/llama_example.py
@@ -5,6 +5,7 @@
import os
import sys
from argparse import Namespace
from typing import TYPE_CHECKING, Callable

import torch
import torch.distributed as dist
@@ -30,6 +31,26 @@
save_peft_adapter,
)

if TYPE_CHECKING:
from vectorlm.sampling.utils import AbstractSamplingEngine


SAMPLER_NOT_PROVIDED_ERROR_MSG = """
Hot-swap sampling is enabled but sampler engine is not provided. \
Did you launch this script via `torchrun llama_example.py`? \
To enable hot-swap vLLM sampling during training, launch the \
training script via `python3 lora_hotswap_example.py` directly \
without using torchrun, especially when running in multi-GPU environments. \

Custom logic in lora_hotswap_example is required to handle multi-GPU \
synchronization and to prevent NCCL conflicts with the vLLM Engine when \
running in multi-GPU setups. \

If you have renamed llama_example.py, be sure to adjust the import in \
lora_hotswap_example.py to load the correct `main` function for the training \
loop.
"""


def parse_args() -> Namespace:
"""Parse command-line arguments.
@@ -48,8 +69,30 @@ def parse_args() -> Namespace:
return parser.parse_args()


def main(config: Config) -> None:
"""Define the main calling function."""
def main(
config: Config,
get_sampling_engine: Callable[[], AbstractSamplingEngine] | None = None,
) -> None:
"""Define the main calling function.

WORLD_SIZE, LOCAL_RANK, and RANK are retrieved from environment vars.

Args:
----
config: vectorlm config, e.g., loaded from yaml
get_sampling_engine: optional, blocking function that initializes the
sampling engine. Required if sampling during training is needed.
This method is provided in _VLLMCallbackWrapper. To avoid concurrent
nccl access, be sure to invoke this method before any torch method
that might also access nccl.

"""
sampling_engine = (
get_sampling_engine() if get_sampling_engine is not None else None
)
if config.train_parameters.get("sampler") is not None:
assert sampling_engine is not None, SAMPLER_NOT_PROVIDED_ERROR_MSG

training_args = config.train_parameters

# set a seed
@@ -66,7 +109,7 @@ def main(config: Config) -> None:
torch.cuda.empty_cache()

# setup wandb
if rank == 0:
if rank == 0 and config.enable_wandb_logging:
wandb_setup(config, **config.wandb_config)
dist.barrier()

@@ -87,7 +130,9 @@ def main(config: Config) -> None:
is_lora_enabled = True
peft_adapter_path = None
# Restore peft adapter from filesystem if available.
if checkpoint_exists(training_args.output_dir):
if (training_args.checkpointing_enabled) and checkpoint_exists(
training_args.output_dir,
):
peft_adapter_path = os.path.join(
training_args.output_dir,
"checkpoints",
@@ -115,6 +160,9 @@
training_args.low_cpu_mem_usage,
is_lora_enabled,
)
# Trigger FSDP initialization before retrieving weights.
# Otherwise FSDP is_root flag might be set incorrectly.
model(input_ids=torch.zeros((1, 1), dtype=torch.int))

# load dataset
dataset = Dataset(
@@ -151,6 +199,7 @@
dataset,
optimizer,
lr_scheduler,
sampling_engine,
is_peft_adapter_restored,
)

@@ -159,7 +208,6 @@
checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir)

for epoch in range(checkpointed_epoch, training_args.epochs):
trainer.model.train()
train_dl_iterator = iter(dataset.train_dataloader)
for _ in tqdm(
range(len(dataset.train_dataloader)),
@@ -185,6 +233,8 @@
save_consolidated_model(trainer.model, hf_save_dir, rank)
dataset.reset_dataloaders()

sys.exit(0)


if __name__ == "__main__":
args = parse_args()
85 changes: 85 additions & 0 deletions examples/lora_hotswap_example.py
@@ -0,0 +1,85 @@
"""Supply LoRASamplingEngine to llama_example.

Each non-rank-0 worker process should spawn the vectorlm logic in a
separate thread (within the same process), but must not run that logic
until the vLLM Engine is initialized, i.e., until the inference weights
are loaded into each worker.

To do so without duplicating vLLM code, observe that only the main process
(rank 0) knows when the vLLM engine has been initialized properly
(when LLMEngine.__init__ returns). Hence, one way to implement this setup
is to block each vectorlm thread on a multiprocessing synchronization
primitive (e.g., a Barrier shared across all processes) that the rank 0
process can remotely unblock.

See docs.google.com/presentation/d/1FCa5O8RYYkRRCAAcXhqCvomePo5fEfhjQciSteTEJ30
for more detail on this architecture.
"""

from __future__ import annotations

import argparse
import os
from functools import partial

from llama_example import main
from vllm import EngineArgs
from vllm.executor.multiproc_worker_utils import ResultHandler, mp

from vectorlm.sampling import (
LoRASamplingEngine,
SamplingEngineProvider,
SynchronizationBarriers,
)
from vectorlm.utils.data_utils import Config

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--world_size", type=int, default=1)
parser.add_argument("--yaml_path", type=str, required=True)
args = parser.parse_args()

world_size: int = args.world_size
vectorlm_config = Config(yaml_path=args.yaml_path)
sampler_config = vectorlm_config.train_parameters.sampler # type: ignore[reportAttributeAccessIssue]
vllm_engine_config = EngineArgs(
model=vectorlm_config.model, # type: ignore[reportAttributeAccessIssue]
gpu_memory_utilization=sampler_config.get(
"gpu_memory_utilization",
0.35,
),
tensor_parallel_size=world_size,
dtype=sampler_config.get("vllm_dtype", "auto"),
enable_lora=True,
).create_engine_config()
os.environ["WORLD_SIZE"] = str(world_size)

# Block all N vectorlm threads until main process finished
# initializing vLLM Engine. Additionally, block vectorlm
# threads as long as vLLM tasks are running.
barriers = SynchronizationBarriers(
# (n+1) threads: __main__, and n vectorlm threads (including main).
vllm_init=mp.Barrier(world_size + 1),
# n vectorlm threads.
before_generation=mp.Barrier(world_size),
after_generation=mp.Barrier(world_size),
)
vllm_result_handler = ResultHandler()

# rank 0 worker runs in the __main__ process.
# all other ranks use one process each.
# vectorlm logic in each ranks (including rank 0) is in a separate thread
# from the vLLM worker logic.
vllm_callback_wrapper = SamplingEngineProvider(
vllm_engine_config,
barriers,
LoRASamplingEngine,
partial(main, vectorlm_config),
)

vllm_callback_wrapper.initialize_engine()
assert vllm_callback_wrapper.llm is not None
output = vllm_callback_wrapper.llm.generate("Vector Institute is")
print(output[0].prompt + output[0].outputs[0].text)

vllm_callback_wrapper.join_vectorlm_thread()
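
The module docstring above describes the barrier pattern: every vectorlm thread blocks until the rank 0 process, once `LLMEngine.__init__` has returned, releases them. As a standalone illustration of that pattern using only the Python standard library (nothing below is specific to this repository; the sleep stands in for vLLM engine initialization):

```
import multiprocessing as mp
import threading
import time


def vectorlm_thread(rank, vllm_ready):
    # Block until all ranks plus the coordinating __main__ process have
    # reached the barrier, i.e., until "vLLM init" has completed.
    vllm_ready.wait()
    print(f"rank {rank}: vLLM is ready, running training logic")


def worker(rank, vllm_ready):
    # Each non-rank-0 process runs its vectorlm logic in a separate thread.
    t = threading.Thread(target=vectorlm_thread, args=(rank, vllm_ready))
    t.start()
    t.join()


if __name__ == "__main__":
    world_size = 2
    # world_size vectorlm threads plus the __main__ process itself.
    barrier = mp.Barrier(world_size + 1)
    procs = [
        mp.Process(target=worker, args=(rank, barrier))
        for rank in range(1, world_size)
    ]
    for p in procs:
        p.start()

    # Rank 0's vectorlm thread lives in the __main__ process.
    rank_zero_thread = threading.Thread(
        target=vectorlm_thread, args=(0, barrier),
    )
    rank_zero_thread.start()

    time.sleep(1.0)  # stand-in for LLMEngine.__init__ finishing on rank 0
    barrier.wait()   # rank 0 releases every vectorlm thread at once

    rank_zero_thread.join()
    for p in procs:
        p.join()
```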
Empty file added profiling/__init__.py
10 changes: 10 additions & 0 deletions vectorlm/sampling/__init__.py
@@ -0,0 +1,10 @@
from .abstract import AbstractSamplingEngine
from .sampling_lora import LoRASamplingEngine
from .utils import (
ManagedLLM,
ManagedMultiProcGPUExecutor,
SamplingEngineProvider,
SynchronizationBarriers,
handle_sample,
multiprocess_wrap,
)