diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 9bb84635b7d3b..6b32a52071860 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -33,7 +33,6 @@
 else:
     flashinfer_top_k_top_p_sampling = None
 
-FORCE_GREEDY = os.environ.get('VLLM_FORCE_GREEDY_SAMPLE', '0').lower() in ['1', 'true']
 
 def get_sampler() -> torch.nn.Module:
     if envs.VLLM_USE_V1:
@@ -189,7 +188,6 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False
         self.should_modify_greedy_probs_inplace = False
-        self.force_greedy_sample = FORCE_GREEDY
 
     def _init_sampling_tensors(
         self,
@@ -275,9 +273,8 @@ def forward(
 
         # Use float32 to apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
-        if not self.force_greedy_sample:
-            logits = logits.to(torch.float)
-            logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+        logits = logits.to(torch.float)
+        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             # If we have a scalar p and k, we can use the optimized version.
@@ -294,14 +291,14 @@ def forward(
 
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
-        probs = None if self.force_greedy_sample else torch.softmax(logits, dim=-1, dtype=torch.float)
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # Compute the log probabilities.
-        logprobs = None if self.force_greedy_sample else torch.log_softmax(logits, dim=-1, dtype=torch.float)
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
         maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
-            logits if self.force_greedy_sample else probs,
-            logits if self.force_greedy_sample else logprobs,
+            probs,
+            logprobs,
             sampling_metadata,
             sampling_tensors,
             include_gpu_probs_tensor=self.include_gpu_probs_tensor,
@@ -313,7 +310,7 @@ def forward(
             # preserve GPU-side tensors in support of later
             # deferred pythonization of logprobs
             assert maybe_sampled_tokens_tensor is not None
-            on_device_tensors = (logits if self.force_greedy_sample else probs, logits if self.force_greedy_sample else logprobs, maybe_sampled_tokens_tensor)
+            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
         else:
             # Since Pythonization has already happened, don't preserve
             # GPU-side tensors.
@@ -327,7 +324,7 @@ def forward(
             assert not isinstance(maybe_deferred_sample_results,
                                   SampleResultArgsType)
             prompt_logprobs, sample_logprobs = get_logprobs(
-                logits if self.force_greedy_sample else logprobs, sampling_metadata, maybe_deferred_sample_results)
+                logprobs, sampling_metadata, maybe_deferred_sample_results)
 
         return _build_sampler_output(
             maybe_deferred_sample_results,
@@ -862,6 +859,7 @@ def get_pythonized_sample_results(
         for i in range(len(sampling_metadata.seq_groups))
     ]
 
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -993,6 +991,7 @@ def _sample_with_torch(
         sampled_token_ids_tensor,
     )
 
+
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -1083,7 +1082,7 @@ def get_logprobs(
     # The largest requested number of logprobs. We find logprobs as many as the
     # largest num logprobs in this API. If every logprobs is None, it will be
     # set to -1.
-    largest_num_logprobs = -float("inf") if FORCE_GREEDY else -1 # If we skipped the logsoftmax (i.e logprobs is just the logits), then the starting min should be -inf)
+    largest_num_logprobs = -1
 
     # Select indices to compute logprob from, ranks of token ids, and the top
     # k token ids from logprobs.
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 24356f81ca696..605c09b8d7225 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -439,8 +439,7 @@ def update_from_generation_config(
 
     @cached_property
     def sampling_type(self) -> SamplingType:
-        # tricky change: use greedy sampling if top_k == 1
-        if self.temperature < _SAMPLING_EPS or self.top_k == 1:
+        if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
         if self.seed is not None:
             return SamplingType.RANDOM_SEED
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 2701af9121973..0f8c6ac72539a 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -11,6 +11,7 @@
 import math
 import os
 import time
+import copy
 from array import array
 from enum import IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
@@ -74,9 +75,6 @@
 
 LORA_WARMUP_RANK = 8
 
-VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true'
-DUMMY_TOKEN_ID = -1
-
 
 def subtuple(obj: object,
              typename: str,
@@ -748,10 +746,8 @@ def __init__(
             raise ValueError(
                 "Speculative decoding is not supported with "
                 "contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
-        # For both multi-step scheduling and delayed sampling
+        # For multi-step scheduling
        self.cached_step_outputs: List[torch.Tensor] = []
-        # For delayed sampling
-        self.cached_step_inputs: List[ModelInputForHPUWithSamplingMetadata] = []
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -888,7 +884,7 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
                 for seq_group_metadata in seq_group_metadata_list)
             temperature = 0.0 if has_greedy_samples else 1.0
             dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                -1, 0, is_prompt, temperature=temperature)
+                0, 0, is_prompt, temperature=temperature)
             seq_group_metadata_list.extend(dummy_seq_group_metadata
                                            for _ in range(batch_size_padding))
         return seq_group_metadata_list, real_batch_size, batch_size_padded
@@ -2089,13 +2085,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             ("HabanaWorker.determine_num_available_blocks needs "
              "to be called before warming up the model.")
         free_mem = HabanaMemoryProfiler.current_free_device_memory()
-        graph_free_mem = free_mem
+        graph_free_mem = free_mem - self.mem_margin
         graph_free_mem = align_workers(graph_free_mem,
                                        torch.distributed.ReduceOp.MIN)
         prompt_graph_mem_ratio = float(
             os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
-        prompt_available_memory = graph_free_mem
-        decode_available_memory = graph_free_mem
+        prompt_available_memory = (prompt_graph_mem_ratio *
+                                   graph_free_mem)
+        decode_available_memory = (graph_free_mem -
+                                   prompt_available_memory)
         msg = (
             f"Using {format_bytes(graph_free_mem)}"
             f"/{format_bytes(free_mem)} "
@@ -2399,20 +2397,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
-    def _get_seq_ids(self, model_input):
-        return ([sg.seq_ids[0]
-                 for sg in model_input.sampling_metadata.seq_groups])
-
-    def _pad_to_max_num_seqs(self, tensor, value):
-        padding_needed = self.max_num_seqs - tensor.size(0)
-        if padding_needed:
-            padding = torch.full((padding_needed, *tensor.shape[1:]),
-                                 value,
-                                 device=tensor.device,
-                                 dtype=tensor.dtype)
-            tensor = torch.cat([tensor, padding])
-        return tensor
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2424,26 +2408,6 @@ def execute_model(
         previous_hidden_states: Optional[torch.Tensor] = None,
         seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
-        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
-        assert not (use_delayed_sampling and num_steps != 1), \
-            'Delayed sampling is not compatible with MSS!'
-        if use_delayed_sampling and not model_input.is_prompt:
-            num_cached = len(self.cached_step_outputs)
-            assert num_cached > 0
-            cur_seq_ids = self._get_seq_ids(model_input)
-            cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0}
-            htorch.core.mark_step()
-            for i in range(num_cached):
-                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
-                target_indices = [cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids]
-                padding = self.cached_step_outputs[i].size(0) - len(target_indices)
-                target_indices.extend([-1] * padding)
-                target_indices = torch.tensor(target_indices,
-                                              device=model_input.input_tokens.device,
-                                              dtype=model_input.input_tokens.dtype)
-                model_input.input_tokens.index_copy_(0, target_indices, self.cached_step_outputs[i])
-                htorch.core.mark_step()
-
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
@@ -2510,7 +2474,7 @@ def execute_model(
                                f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        if num_steps > 1 or use_delayed_sampling:
+        if num_steps > 1:
             # in case of multi-step scheduling
             # we only want to pythonize in the last step
             sampling_metadata.skip_sampler_cpu_output = True
@@ -2575,9 +2539,9 @@ def try_revert_dummy_output_tokens():
 
                 if not self.is_driver_worker:
                     continue
-                if use_delayed_sampling:
-                    fake_output = self._delayed_sampler_outputs(model_input)
-
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
+                # Sample the next token.
                 with self.profiler.record_event(
                         'internal', ('sample_'
                                      f'{"prompt" if is_prompt else "decode"}_'
@@ -2589,16 +2553,9 @@ def try_revert_dummy_output_tokens():
                 )
                 if num_steps > 1:
                     output = output.sampled_token_ids
-                    self.cached_step_outputs.append(output)
-                if use_delayed_sampling:
-                    self._patch_prev_output()
-                    output = self._pad_to_max_num_seqs(
-                        output.sampled_token_ids, DUMMY_TOKEN_ID)
-                    self.cached_step_outputs.append(output)
-                    self.cached_step_inputs.append(model_input)
+                    self.cached_step_outputs.append(
+                        output.detach().clone())
                 htorch.core.mark_step()
-                if model_input.async_callback is not None:
-                    model_input.async_callback()
                 if i < num_steps - 1:
                     if i == 0:
                         if model_input.async_callback is not None:
@@ -2692,21 +2649,12 @@ def try_revert_dummy_output_tokens():
             if model_input.is_prompt:
                 output.prefill_hidden_states = hidden_states
                 output.hidden_states = hidden_states
-            if use_delayed_sampling:
-                return [fake_output]
-
             return [output] if self.is_driver_worker else []
         else:
             return []
 
         return output if type(output) is list else [output]
 
-    def _delayed_sampler_outputs(self, model_input):
-        next_token_ids = [[DUMMY_TOKEN_ID]] * len(model_input.sampling_metadata.seq_groups)
-        sampler_output = self._make_decode_output(
-            next_token_ids, model_input.sampling_metadata.seq_groups)
-        return sampler_output
-
     def _decode_sampler_outputs(self, model_input):
         use_async_out_proc = model_input.async_callback is not None
         sampler_outputs = []
@@ -2756,24 +2704,3 @@ def _make_decode_output(
             sampler_outputs.append(
                 CompletionSequenceGroupOutput(seq_outputs, None))
         return SamplerOutput(sampler_outputs)
-
-    def _patch_prev_output(self):
-        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
-            f'Inputs and outputs are out of sync! {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'
-        if len(self.cached_step_inputs) == 0:
-            return
-        model_input = self.cached_step_inputs.pop(0)
-        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(-1).tolist()
-        ctx = model_input.async_callback.keywords["ctx"]
-        assert len(ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
-        output_data = ctx.output_queue[0]
-        assert len(output_data.outputs) == 1
-        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
-            fake_out.samples[0].output_token = real_out
-        for sg, real_out in zip(output_data.seq_group_metadata_list, delayed_output):
-            assert len(sg.seq_data) == 1
-            seq_data = list(sg.seq_data.values())[0]
-            # This is a hack. Assigning output_token_ids triggers
-            # a cache recomputation and we only need to update the last token
-            seq_data.output_token_ids_array[-1] = real_out
-            seq_data._cached_all_token_ids[-1] = real_out
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 1322aac5af2ba..969971f2e25cd 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -314,7 +314,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         logger.info(msg)
         # At this point we should've allocated the maximum workspace for all
         # recipes we will use the extra memory for graphs/blocks
-        free_hpu_memory = torch.hpu.mem_get_info()[0] * 2.0
+        free_hpu_memory = torch.hpu.mem_get_info()[0]
 
         cache_block_size = self.get_cache_block_size_bytes()
         graph_reserved_mem = (float(
@@ -336,7 +336,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
             f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
             f"{format_bytes(cache_size_bytes)} reserved for KV cache")
         logger.info(msg)
-        num_hpu_blocks = int(os.environ.get('VLLM_NUM_HPU_BLOCKS', '3072'))
+        num_hpu_blocks = int(cache_size_bytes // cache_block_size)
         num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                              cache_block_size)
         num_hpu_blocks = max(num_hpu_blocks, 0)
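
Note (not part of the patch): a minimal, self-contained sketch of the memory accounting this diff restores in warmup_model() and determine_num_available_blocks(). The function names and the sample byte values below are hypothetical; only the formulas mirror the diff (graph memory = free memory minus the worker's memory margin, split by VLLM_GRAPH_PROMPT_RATIO, and the HPU block count derived from the cache budget rather than a fixed environment default).

# Illustrative sketch only -- not part of the patch. Names and numbers are
# hypothetical; the formulas mirror the hunks above.

def split_graph_memory(free_mem: int, mem_margin: int,
                       prompt_ratio: float = 0.3) -> tuple:
    # graph_free_mem = free_mem - mem_margin, then split between prompt and
    # decode graphs by VLLM_GRAPH_PROMPT_RATIO (default 0.3 in the diff).
    graph_free_mem = free_mem - mem_margin
    prompt_available = int(prompt_ratio * graph_free_mem)
    decode_available = graph_free_mem - prompt_available
    return prompt_available, decode_available

def num_hpu_blocks(cache_size_bytes: int, cache_block_size: int) -> int:
    # Block count is now derived from the measured KV-cache budget instead of
    # the removed VLLM_NUM_HPU_BLOCKS environment override.
    return max(int(cache_size_bytes // cache_block_size), 0)

if __name__ == "__main__":
    # Hypothetical sizes: 8 GiB free, 1 GiB margin, 6 GiB cache, 2 MiB blocks.
    print(split_graph_memory(free_mem=8 << 30, mem_margin=1 << 30))
    print(num_hpu_blocks(cache_size_bytes=6 << 30, cache_block_size=2 << 20))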