diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 55e4e14027f79..5b9930f55a7e6 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -134,15 +134,10 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + _kv_cache: torch.Tensor, + _attn_metadata: AttentionMetadata, ) -> torch.Tensor: - - if self.use_direct_call: - return self.impl.forward(query, key, value, kv_cache, - attn_metadata, self._k_scale, - self._v_scale) - elif self.use_output: + if self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) # Reshape the query, key, and value tensors. @@ -154,12 +149,19 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) - torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) + if self.use_direct_call: + unified_attention_with_output(query, key, value, output, + self.layer_name) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name) return output.view(-1, hidden_size) else: - return torch.ops.vllm.unified_attention(query, key, value, - self.layer_name) + if self.use_direct_call: + return unified_attention(query, key, value, self.layer_name) + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/utils.py b/vllm/utils.py index 217ccb25cef6d..9a509da3c1ef1 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2171,5 +2171,4 @@ def bind_kv_cache( forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): - assert forward_ctx.kv_cache[ve].numel() == 0 forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9d479f412af46..f5d1df1ffab21 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -28,6 +28,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import DeviceConfig, VllmConfig from vllm.distributed.parallel_state import get_world_group +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -1943,7 +1944,11 @@ def execute_model( f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' - with self.profiler.record_event('internal', model_event_name): + with set_forward_context( + model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine), \ + self.profiler.record_event( + 'internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index ae4eb6ba6eaec..a35f5467e1a1f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -8,6 +8,7 @@ from transformers_neuronx.config import GenerationConfig from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -314,13 +315,15 @@ def execute_model( raise ValueError( "NeuronModelRunner does not support multi-step execution.") - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - ) + with set_forward_context(None, self.vllm_config, 0): + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), + ) # Compute the logits only if the on-device sampling is turned off as # on-device sampling outputs the token ids. diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6000e5dfe4e30..a38b5a4e6e8d5 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -8,6 +8,7 @@ from vllm.attention import get_attn_backend from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -350,7 +351,8 @@ def execute_model( device=self.device), } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(attn_metadata, self.vllm_config, 0): + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7bdb7f0e2d6a9..52c577bccab9c 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -265,8 +266,9 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + with set_forward_context(attn_metadata, self.vllm_config, 0): + self.model(token_ids, position_ids, attn_metadata, input_lens, t, + p, num_samples, kv_caches) def warmup_model( self, @@ -663,10 +665,13 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, p, - model_input.num_samples, - kv_caches) + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, + p, model_input.num_samples, + kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -711,10 +716,13 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, position_ids, - attn_metadata, input_lens, t, p, - model_input.num_samples, - kv_caches) + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, + p, model_input.num_samples, + kv_caches) self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 8754f7538f251..ea0e700545b16 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, @@ -108,6 +108,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.tensor([], dtype=torch.float32, device=self.device)) for _ in range(num_layers)] + bind_kv_cache(self.compilation_config.static_forward_context, + [kv_caches]) self.model_runner._dummy_run( batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, @@ -170,6 +172,8 @@ def initialize_cache( device="cpu") cpu_v_cache = torch.zeros_like(cpu_k_cache) self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) + bind_kv_cache(self.compilation_config.static_forward_context, + [self.tpu_cache]) self._warmup_model() def _warmup_model(self) -> None: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9cf25387560da..82b8f22a5af33 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -12,6 +12,7 @@ from vllm.attention import get_attn_backend from vllm.config import VllmConfig from vllm.distributed import get_pp_group +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadataCache @@ -562,15 +563,17 @@ def execute_model( if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start_time = time.time() - - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device)) + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device)) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states