[Spec Decode] feat: support LoRA with speculative decoding #11966

Open · wants to merge 2 commits into main
Conversation

@llsj14 (Contributor) commented on Jan 12, 2025

Summary

  • This PR adds support for LoRA with speculative decoding.

Implementation

  • Two problems had to be solved to apply LoRA in Spec Decode (a hedged sketch of both fixes follows this list).
  1. The LoRA adapter is designed for the target model, so applying the same adapter to the draft model can cause errors such as the one below. Until the API is extended to inject a separate LoRA adapter for the draft model, the adapter is applied only to the target model and is temporarily disabled for the draft model.
ERROR 01-12 03:08:10 engine.py:135]     result = super().activate_adapter(lora_id)
ERROR 01-12 03:08:10 engine.py:135]   File "/vllm/vllm/lora/models.py", line 405, in activate_adapter
ERROR 01-12 03:08:10 engine.py:135]     module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
ERROR 01-12 03:08:10 engine.py:135]   File "/vllm/vllm/lora/layers.py", line 215, in set_lora
ERROR 01-12 03:08:10 engine.py:135]     self.lora_b_stacked[index,
ERROR 01-12 03:08:10 engine.py:135] RuntimeError: The size of tensor a (768) must match the size of tensor b (4096) at non-singleton dimension 0
  2. As shown in this Llama model implementation, when LoRA is enabled the vocabulary size of both the target and draft models is padded further for kernel compatibility. The vocabulary size used by the SpecDecodeWorker therefore has to be adjusted to match the padded size; without this adjustment, the following error occurs:
ERROR 01-12 03:06:35 engine.py:135]   File "/vllm/vllm/lora/models.py", line 405, in activate_adapter
ERROR 01-12 03:06:35 engine.py:135]     module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
ERROR 01-12 03:06:35 engine.py:135]   File "/vllm/vllm/lora/layers.py", line 215, in set_lora
ERROR 01-12 03:06:35 engine.py:135]     self.lora_b_stacked[index,
ERROR 01-12 03:06:35 engine.py:135] RuntimeError: The size of tensor a (768) must match the size of tensor b (4096) at non-singleton dimension 0
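
Below is a minimal sketch of the two fixes described above, for illustration only; the function names (maybe_strip_lora, padded_vocab_size) and attribute paths are assumptions, not the actual diff in this PR.

# Problem 1: the draft model has no matching LoRA weights, so drop LoRA
# requests before the proposer (draft) worker runs; the target model keeps them.
def maybe_strip_lora(execute_model_req):
    for seq_group_metadata in execute_model_req.seq_group_metadata_list:
        seq_group_metadata.lora_request = None  # temporarily disable LoRA for the draft model
    return execute_model_req

# Problem 2: mirror the padding the Llama model applies to its lm_head when
# LoRA is enabled, so the SpecDecodeWorker compares logits with the padded vocab size.
def padded_vocab_size(vocab_size, lora_config):
    if lora_config is None:
        return vocab_size
    vocab_size += lora_config.lora_extra_vocab_size  # slots reserved for LoRA-added tokens
    pad = lora_config.lora_vocab_padding_size        # larger padding required by the LoRA kernels
    return ((vocab_size + pad - 1) // pad) * pad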

Test

  • Running Server
python -m vllm.entrypoints.openai.api_server \
     --model meta-llama/Llama-2-7b-hf \
     --port 8080 \
     --disable-custom-all-reduce \
     --swap-space 0 \
     --gpu-memory-utilization 0.9 \
     --enable-lora \
     --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ \
     --speculative_model JackFram/llama-68m \
     --num_speculative_tokens 3
  • Request w/ LoRA Adapter
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "sql-lora",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
                "top_k": 1,
                "top_p": 1.0,
        "temperature": 1.0
    }' | jq
  • Response
"text": " city in California, United States, on the tip of a peninsula between the Pacific Ocean and San Francisco Bay. San Francisco is a leading financial center and"
  • Request w/o LoRA adapter
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/mnt/lvm/checkpoints/hugginface/Llama-2-7b-hf",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
                "top_k": 1,
                "top_p": 1.0,
        "temperature": 1.0
    }' | jq
  • Response
"text": " city of many neighborhoods, each with its own distinct personality. San Francisco is a city of many neighborhoods, each with its own distinct personality."

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@llsj14 changed the title from "feat: support LoRA with speculative decoding" to "[Spec Decode] feat: support LoRA with speculative decoding" on Jan 12, 2025
Signed-off-by: Sungjae Lee <[email protected]>