I'm running Llama 3 70B with vLLM on a TPU v4-16. With the flash attention kernel I can go up to a 16k sequence length, but with `multi_queries_paged_attention` at sequence length 256 it seems the page table is taking too much SMEM.

@vanbasten23 @WoosukKwon, any idea how to address this? (I'm familiar with Pallas programming.) Maybe something along the lines of this? https://github.com/vllm-project/vllm/blob/02222a0256f60319f5bcd56d1d036a943d6334f8/vllm/attention/backends/pallas.py#L260
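To make the SMEM pressure concrete, here is a rough back-of-the-envelope estimate of how large a dense block table gets as the batch of warm-up shapes and the max model length grow. This is only a sketch: the `page_table_bytes` helper, the `block_size=16` page size, the int32 entry size, and the batch/length values are placeholder assumptions for illustration, not numbers read out of the vLLM or Pallas kernels, and the actual SMEM budget depends on the TPU generation.

```python
# Back-of-the-envelope estimate of the page-table footprint the paged
# attention kernel has to keep resident. All concrete numbers below are
# illustrative assumptions, not values taken from vLLM.

def page_table_bytes(batch_size: int, max_model_len: int, block_size: int,
                     entry_bytes: int = 4) -> int:
    """Size of a dense (batch, max_blocks_per_seq) int32 block table."""
    max_blocks_per_seq = -(-max_model_len // block_size)  # ceil division
    return batch_size * max_blocks_per_seq * entry_bytes

# Hypothetical sweep: page size 16 and these batch sizes are placeholders.
for batch in (8, 64, 256):
    for max_len in (256, 2048, 16384):
        kib = page_table_bytes(batch, max_len, block_size=16) / 1024
        print(f"batch={batch:4d} max_len={max_len:6d} -> {kib:8.1f} KiB of page table")
```

Even at modest batch sizes the table grows linearly with both batch and max length, which is consistent with the flash-attention path (no per-block table) scaling to 16k while the paged path hits a limit much earlier.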
```
Loading safetensors checkpoint shards: 100% Completed | 30/30 [02:03<00:00, 4.13s/it]
INFO 12-21 14:11:07 ray_tpu_executor.py:276] # TPU blocks: 19032, # CPU blocks: 6552
INFO 12-21 14:11:07 tpu_model_runner.py:274] Compiling the model with different input shapes...
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:11:08 tpu_model_runner.py:274] Compiling the model with different input shapes...
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 tpu.py:27] Cannot use _Backend.FLASH_ATTN backend on TPU. [repeated 6x across cluster]
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 selector.py:163] Using Pallas backend. [repeated 6x across cluster]
(RayWorkerWrapper pid=1005) WARNING 12-21 14:07:13 tpu_worker.py:62] Starting to init distributed environment with config: ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=8, worker_use_ray=False, max_parallel_loading_workers=None, disable_custom_all_reduce=False, tokenizer_pool_config=None, ray_workers_use_nsight=False, placement_group=<ray.util.placement_group.PlacementGroup object at 0x7f05501350f0>, distributed_executor_backend='ray', worker_cls='vllm.worker.tpu_worker.TPUWorker', sd_worker_cls='auto', world_size=8, rank=3) [repeated 6x across cluster]
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 parallel_state.py:954] world_size=8 rank=3 local_rank=3 distributed_init_method=tcp://10.130.0.185:57577 backend=gloo [repeated 6x across cluster]
(RayWorkerWrapper pid=1005) INFO 12-21 14:07:13 parallel_state.py:959] attempting to initialize distributed environment [repeated 6x across cluster]
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_world_group: local_rank=3 [repeated 12x across cluster]
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_world_group: backend='gloo' [repeated 6x across cluster]
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_model_parallel_group bla bla: local_rank=3 [repeated 26x across cluster]
(RayWorkerWrapper pid=1135, ip=10.130.0.186) init_model_parallel_group bla bla: backend='gloo' [repeated 13x across cluster]
(RayWorkerWrapper pid=1005) self.cpu_group=<torch.distributed.distributed_c10d.ProcessGroup object at 0x7f051028d330> [repeated 6x across cluster]
INFO 12-21 14:13:02 tpu_model_runner.py:284] batch_size: 1, seq_len: 16
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:02 tpu_model_runner.py:284] batch_size: 1, seq_len: 16
(RayWorkerWrapper pid=895, ip=10.130.0.186) INFO 12-21 14:11:08 tpu_model_runner.py:274] Compiling the model with different input shapes... [repeated 6x across cluster]
INFO 12-21 14:13:05 tpu_model_runner.py:284] batch_size: 1, seq_len: 32
INFO 12-21 14:13:07 tpu_model_runner.py:284] batch_size: 1, seq_len: 64
(RayWorkerWrapper pid=995) INFO 12-21 14:13:07 tpu_model_runner.py:284] batch_size: 1, seq_len: 64 [repeated 18x across cluster]
INFO 12-21 14:13:10 tpu_model_runner.py:284] batch_size: 1, seq_len: 128
INFO 12-21 14:13:12 tpu_model_runner.py:284] batch_size: 1, seq_len: 256
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:12 tpu_model_runner.py:284] batch_size: 1, seq_len: 256 [repeated 10x across cluster]
INFO 12-21 14:13:15 tpu_model_runner.py:284] batch_size: 1, seq_len: 512
INFO 12-21 14:13:19 tpu_model_runner.py:284] batch_size: 1, seq_len: 1024
(RayWorkerWrapper pid=995) INFO 12-21 14:13:19 tpu_model_runner.py:284] batch_size: 1, seq_len: 1024 [repeated 14x across cluster]
INFO 12-21 14:13:22 tpu_model_runner.py:284] batch_size: 1, seq_len: 2048
INFO 12-21 14:13:27 tpu_model_runner.py:284] batch_size: 1, seq_len: 4096
(RayWorkerWrapper pid=995) INFO 12-21 14:13:27 tpu_model_runner.py:284] batch_size: 1, seq_len: 4096 [repeated 14x across cluster]
INFO 12-21 14:13:32 tpu_model_runner.py:284] batch_size: 1, seq_len: 8192
(RayWorkerWrapper pid=995) INFO 12-21 14:13:32 tpu_model_runner.py:284] batch_size: 1, seq_len: 8192 [repeated 7x across cluster]
INFO 12-21 14:13:38 tpu_model_runner.py:284] batch_size: 1, seq_len: 16384
INFO 12-21 14:13:38 tpu_model_runner.py:291] Compilation for prefill done in 150.46 s.
INFO 12-21 14:13:38 tpu_model_runner.py:295] Compiling the model with different input shapes for prefix prefill...
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:38 tpu_model_runner.py:291] Compilation for prefill done in 149.52 s.
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:38 tpu_model_runner.py:295] Compiling the model with different input shapes for prefix prefill...
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:13:38 tpu_model_runner.py:284] batch_size: 1, seq_len: 16384 [repeated 7x across cluster]
INFO 12-21 14:15:53 tpu_model_runner.py:306] batch_size: 1, seq_len: 16
(RayWorkerWrapper pid=1005) INFO 12-21 14:13:38 tpu_model_runner.py:291] Compilation for prefill done in 149.50 s. [repeated 6x across cluster]
(RayWorkerWrapper pid=1005) INFO 12-21 14:13:38 tpu_model_runner.py:295] Compiling the model with different input shapes for prefix prefill... [repeated 6x across cluster]
(RayWorkerWrapper pid=995) INFO 12-21 14:15:53 tpu_model_runner.py:306] batch_size: 1, seq_len: 16 [repeated 7x across cluster]
INFO 12-21 14:16:31 tpu_model_runner.py:306] batch_size: 1, seq_len: 32
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:16:31 tpu_model_runner.py:306] batch_size: 1, seq_len: 32 [repeated 7x across cluster]
INFO 12-21 14:17:07 tpu_model_runner.py:306] batch_size: 1, seq_len: 64
(RayWorkerWrapper pid=995) INFO 12-21 14:17:07 tpu_model_runner.py:306] batch_size: 1, seq_len: 64 [repeated 7x across cluster]
INFO 12-21 14:17:48 tpu_model_runner.py:306] batch_size: 1, seq_len: 128
(RayWorkerWrapper pid=777, ip=10.130.0.186) INFO 12-21 14:17:48 tpu_model_runner.py:306] batch_size: 1, seq_len: 128 [repeated 7x across cluster]
INFO 12-21 14:18:30 tpu_model_runner.py:306] batch_size: 1, seq_len: 256
(RayWorkerWrapper pid=895, ip=10.130.0.186) INFO 12-21 14:18:30 tpu_model_runner.py:306] batch_size: 1, seq_len: 256 [repeated 7x across cluster]
```
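For anyone less familiar with where the SMEM pressure comes from: in Pallas TPU kernels, scalar-prefetch operands, which is how a block/page table is typically fed to a paged attention kernel, are placed in SMEM so that the `BlockSpec` index maps can read them when scheduling DMAs. The toy kernel below is not the vLLM kernel; it is a minimal sketch of that mechanism, assuming a recent `jax.experimental.pallas` (`PrefetchScalarGridSpec`, keyword-style `BlockSpec`). The `gather_pages` name and all shapes are made up for illustration, and it needs a TPU device to run.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _copy_kernel(page_table_ref, page_ref, out_ref):
    # page_table_ref is the scalar-prefetch operand: it lives in SMEM and is
    # already consumed by the index maps below, so the body is a plain copy.
    out_ref[...] = page_ref[...]


def gather_pages(page_table, pages):
    """Gather pages[page_table[i]] for each i, driving the block schedule
    from a page table held in SMEM (the same pattern paged attention uses)."""
    num_rows = page_table.shape[0]
    _, page_size, head_dim = pages.shape
    return pl.pallas_call(
        _copy_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,  # page_table is prefetched into SMEM
            grid=(num_rows,),
            in_specs=[pl.BlockSpec(
                block_shape=(1, page_size, head_dim),
                # The index map reads the SMEM-resident table to pick which
                # page of `pages` to stream into VMEM at this grid step.
                index_map=lambda i, table: (table[i], 0, 0),
            )],
            out_specs=pl.BlockSpec(
                block_shape=(1, page_size, head_dim),
                index_map=lambda i, table: (i, 0, 0),
            ),
        ),
        out_shape=jax.ShapeDtypeStruct((num_rows, page_size, head_dim),
                                       pages.dtype),
    )(page_table, pages)


# Arbitrary example shapes: 512 pages of (16, 128), gather 64 of them reversed.
pages = jnp.arange(512 * 16 * 128, dtype=jnp.float32).reshape(512, 16, 128)
table = jnp.arange(64, dtype=jnp.int32)[::-1]
out = gather_pages(table, pages)
```

Because the whole table has to be resident in SMEM for the index maps, its size scales with batch size times max blocks per sequence, which is the quantity estimated in the sketch above.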