[V1][WIP] Hybrid allocator for full attention & sliding window attention interleaved models (Reference PR, do not merge) #11938

Draft · heheda12345 wants to merge 15 commits into main
Conversation

heheda12345 (Collaborator) commented Jan 10, 2025

This PR implements step 1 of #11382, so that:

  1. we don't waste memory on models that interleave sliding window and full attention layers
  2. prefix caching works for sliding window attention, where a cache hit only requires that the tokens inside the sliding window have not been evicted (see the sketch below)
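
A minimal sketch of the sliding-window cache-hit condition in point 2 (a hypothetical helper, not the PR's actual code; block_size, sliding_window, and resident_blocks are illustrative names):

    def sliding_window_hit(num_cached_tokens: int,
                           sliding_window: int,
                           block_size: int,
                           resident_blocks: set[int]) -> bool:
        """Can a cached prefix of num_cached_tokens tokens be reused by a
        sliding-window layer? Only the blocks holding the last sliding_window
        tokens of the prefix must still be resident; earlier blocks may have
        been evicted without breaking the hit."""
        if num_cached_tokens == 0:
            return True
        first_needed_token = max(0, num_cached_tokens - sliding_window)
        first_needed_block = first_needed_token // block_size
        last_block = (num_cached_tokens - 1) // block_size
        return all(b in resident_blocks
                   for b in range(first_needed_block, last_block + 1))

For example, with block_size=16 and sliding_window=4096, reusing a 6144-token prefix only requires blocks 128 through 383 to still be resident; blocks 0 through 127 may already have been evicted.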

Benchmark results (this PR accelerates hybrid models and adds very little overhead on standard full-attention models):

  • this PR:
VLLM_USE_V1=1 python3 benchmark_throughput.py --model google/gemma-2-27b-it --input-len 6144 --output-len 1024 --num-prompts 50
Throughput: 0.17 requests/s, 1239.96 total tokens/s, 177.14 output tokens/s
VLLM_USE_V1=1 python3 benchmark_throughput.py --model meta-llama/Llama-3.1-8B-Instruct --input-len 6144 --output-len 1024 --num-prompts 50
Throughput: 1.48 requests/s, 10609.31 total tokens/s, 1515.62 output tokens/s
  • main:
VLLM_USE_V1=1 python3 benchmark_throughput.py --model google/gemma-2-27b-it --input-len 6144 --output-len 1024 --num-prompts 50
Throughput: 0.15 requests/s, 1077.11 total tokens/s, 153.87 output tokens/s
VLLM_USE_V1=1 python3 benchmark_throughput.py --model meta-llama/Llama-3.1-8B-Instruct --input-len 6144 --output-len 1024 --num-prompts 50
Throughput: 1.49 requests/s, 10682.79 total tokens/s, 1526.11 output tokens/s

Key modifications

  1. KV cache initialization:
     • original workflow:
       1. num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks()
       2. self.model_executor.initialize(num_gpu_blocks) (allocate the KV cache)
     • modified workflow:
       # Get the KV cache spec of each layer by parsing the model
       kv_cache_spec = self.model_executor.get_kv_cache_spec()
       # Get available_gpu_memory via profile_run (instead of having the
       # executor determine num_blocks from it)
       available_gpu_memory = self.model_executor.get_available_memory()
       # EngineCore decides the page size & how to create each KV cache tensor
       kv_cache_config, num_gpu_blocks = get_kv_cache_config(
           vllm_config, kv_cache_spec, available_gpu_memory)
       # The executor initializes the KV cache based on that decision
       self.model_executor.initialize(kv_cache_config)
  2. grouped allocation:
     • original: one KVCacheManager that allocates memory for all layers
     • modified (see the sketch after this list):
       • multiple KVCacheManagers, one for each group of layers (see "group the layers" for details). All KVCacheManagers use the same page size and allocate memory from the same pool.
       • add group_id to kv_block_hash
       • block_table in the worker: add a new dimension for groups
       • two KVCacheManager implementations, for full attention and sliding window attention respectively
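
As a rough illustration of the new initialization flow and the grouped managers, here is a minimal sketch (hypothetical names and fields such as KVCacheSpec and page_size_bytes; not the actual vLLM code) of grouping layers by their KV cache spec and sizing the shared block pool from the profiled memory:

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class KVCacheSpec:
        """Per-layer KV cache requirement (illustrative fields only)."""
        attn_type: str                    # "full_attention" or "sliding_window"
        page_size_bytes: int              # bytes one block occupies for this layer
        sliding_window: Optional[int] = None

    def get_kv_cache_config(kv_cache_spec: dict, available_gpu_memory: int):
        # Layers with identical specs share one KVCacheManager (one "group").
        groups = defaultdict(list)
        for layer_name, spec in kv_cache_spec.items():
            groups[spec].append(layer_name)
        # Every group uses the same page size and draws blocks from one shared
        # pool, so the pool is sized by the bytes a block costs across all layers.
        bytes_per_block = sum(spec.page_size_bytes for spec in kv_cache_spec.values())
        num_gpu_blocks = available_gpu_memory // bytes_per_block
        return {"groups": dict(groups), "num_gpu_blocks": num_gpu_blocks}, num_gpu_blocks

    # Example: one sliding-window layer interleaved with one full-attention layer.
    spec = {
        "layers.0.attn": KVCacheSpec("sliding_window", 128 * 1024, sliding_window=4096),
        "layers.1.attn": KVCacheSpec("full_attention", 128 * 1024),
    }
    kv_cache_config, num_gpu_blocks = get_kv_cache_config(spec, 8 * 1024**3)

Sharing one pool is what avoids wasting memory on a hybrid model: blocks not needed by the sliding-window group remain available to the full-attention group.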

I plan to split it into the following PRs:

  1. KV cache initialization, as discussed above.
  2. add a new "group" dimension to the block_table, to represent the different memory allocated for different types of KV cache (see the sketch after this list).
  3. change AttentionMetadata to dict[layer_name, AttentionMetadata]
  4. a very large PR implementing HybridKVCacheManager, a pluggable alternative to KVCacheManager that won't touch the code path for standard models with only full attention layers.
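
To make items 2 and 3 concrete, here is a small hypothetical sketch (shapes, layer names, and the AttentionMetadata fields are illustrative, not the PR's actual code) of a block table with a leading group dimension and of attention metadata keyed by layer name:

    from dataclasses import dataclass

    import torch

    @dataclass
    class AttentionMetadata:
        """Placeholder for whatever per-layer metadata an attention backend needs."""
        seq_lens: torch.Tensor
        block_table: torch.Tensor   # [max_num_reqs, max_blocks_per_req] for this layer's group

    num_groups = 2          # e.g. group 0: full attention, group 1: sliding window
    max_num_reqs = 8
    max_blocks_per_req = 64

    # The worker-side block table gains a leading "group" dimension; each group's
    # KVCacheManager fills its own slice, while all block ids index one shared pool.
    block_table = torch.zeros(num_groups, max_num_reqs, max_blocks_per_req,
                              dtype=torch.int32)

    # AttentionMetadata becomes dict[layer_name, AttentionMetadata], so interleaved
    # layers can each look up the block-table slice of the group they belong to.
    layer_to_group = {"layers.0.attn": 1, "layers.1.attn": 0}
    attn_metadata: dict[str, AttentionMetadata] = {
        name: AttentionMetadata(
            seq_lens=torch.zeros(max_num_reqs, dtype=torch.int32),
            block_table=block_table[group],
        )
        for name, group in layer_to_group.items()
    }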


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Jan 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Jan 10, 2025
heheda12345 changed the title from [V1][WIP] Hybrid allocator for full attention & sliding window attention interleaved models to [V1][WIP] Hybrid allocator for full attention & sliding window attention interleaved models (Reference PR, do not merge) Jan 11, 2025