Revert "Skip sampling softmax" #806

Merged
merged 1 commit on Feb 10, 2025
23 changes: 11 additions & 12 deletions vllm/model_executor/layers/sampler.py
@@ -33,7 +33,6 @@
else:
flashinfer_top_k_top_p_sampling = None

- FORCE_GREEDY = os.environ.get('VLLM_FORCE_GREEDY_SAMPLE', '0').lower() in ['1', 'true']

def get_sampler() -> torch.nn.Module:
if envs.VLLM_USE_V1:
@@ -189,7 +188,6 @@ def __init__(self):
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
- self.force_greedy_sample = FORCE_GREEDY

def _init_sampling_tensors(
self,
@@ -275,9 +273,8 @@ def forward(

# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
- if not self.force_greedy_sample:
-     logits = logits.to(torch.float)
-     logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+ logits = logits.to(torch.float)
+ logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
# If we have a scalar p and k, we can use the optimized version.
@@ -294,14 +291,14 @@ def forward(

# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
- probs = None if self.force_greedy_sample else torch.softmax(logits, dim=-1, dtype=torch.float)
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
- logprobs = None if self.force_greedy_sample else torch.log_softmax(logits, dim=-1, dtype=torch.float)
+ logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
- logits if self.force_greedy_sample else probs,
- logits if self.force_greedy_sample else logprobs,
+ probs,
+ logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
@@ -313,7 +310,7 @@ def forward(
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
- on_device_tensors = (logits if self.force_greedy_sample else probs, logits if self.force_greedy_sample else logprobs, maybe_sampled_tokens_tensor)
+ on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
@@ -327,7 +324,7 @@ def forward(
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
- logits if self.force_greedy_sample else logprobs, sampling_metadata, maybe_deferred_sample_results)
+ logprobs, sampling_metadata, maybe_deferred_sample_results)

return _build_sampler_output(
maybe_deferred_sample_results,
@@ -862,6 +859,7 @@ def get_pythonized_sample_results(
for i in range(len(sampling_metadata.seq_groups))
]


def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
@@ -993,6 +991,7 @@ def _sample_with_torch(
sampled_token_ids_tensor,
)


def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
@@ -1083,7 +1082,7 @@ def get_logprobs(
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
- largest_num_logprobs = -float("inf") if FORCE_GREEDY else -1  # If we skipped the logsoftmax (i.e. logprobs is just the logits), then the starting min should be -inf
+ largest_num_logprobs = -1

# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
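Side note for reviewers: the net effect of the hunks above is that `probs` and `logprobs` are always materialized again and `_sample` always receives them, instead of raw logits being passed through when `VLLM_FORCE_GREEDY_SAMPLE` was set. A minimal, self-contained sketch of the restored flow (illustrative only, not the actual `Sampler.forward`; the helper name and shapes are made up):

```python
import torch

def sample_next_tokens(logits: torch.Tensor,
                       temperatures: torch.Tensor,
                       greedy: bool = False):
    """Toy version of the restored path: temperature-scale the logits,
    compute softmax/log-softmax in float32, then sample."""
    # Use float32 and in-place division, as the comment in the diff describes.
    logits = logits.to(torch.float)
    logits.div_(temperatures.unsqueeze(dim=1))

    # The revert recomputes both tensors unconditionally.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    if greedy:
        # argmax over probs picks the same token as argmax over raw logits,
        # which is the equivalence the reverted shortcut relied on.
        next_tokens = probs.argmax(dim=-1)
    else:
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return next_tokens, logprobs

# Example: a batch of 2 requests over a 5-token vocabulary.
tokens, logprobs = sample_next_tokens(torch.randn(2, 5),
                                      torch.tensor([0.7, 1.0]))
```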
3 changes: 1 addition & 2 deletions vllm/sampling_params.py
@@ -439,8 +439,7 @@ def update_from_generation_config(

@cached_property
def sampling_type(self) -> SamplingType:
- # tricky change: use greedy sampling if top_k == 1
- if self.temperature < _SAMPLING_EPS or self.top_k == 1:
+ if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
return SamplingType.RANDOM_SEED
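For context, the restored rule in `sampling_type` behaves roughly like the sketch below; the `_SAMPLING_EPS` value and the standalone function are assumptions for illustration, not the real `SamplingParams` property:

```python
from enum import Enum
from typing import Optional

_SAMPLING_EPS = 1e-5  # assumed threshold; the real constant lives in vllm


class SamplingType(Enum):
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2


def sampling_type(temperature: float, seed: Optional[int] = None) -> SamplingType:
    # Restored behaviour: only a near-zero temperature selects greedy
    # sampling; the reverted "top_k == 1 implies greedy" shortcut is gone,
    # so top_k == 1 requests go through the normal random-sampling path.
    if temperature < _SAMPLING_EPS:
        return SamplingType.GREEDY
    if seed is not None:
        return SamplingType.RANDOM_SEED
    return SamplingType.RANDOM
```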
101 changes: 14 additions & 87 deletions vllm/worker/hpu_model_runner.py
@@ -11,6 +11,7 @@
import math
import os
import time
+ import copy
from array import array
from enum import IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
@@ -74,9 +75,6 @@

LORA_WARMUP_RANK = 8

- VLLM_DELAYED_SAMPLING = os.environ.get('VLLM_DELAYED_SAMPLING', 'false').lower() == 'true'
- DUMMY_TOKEN_ID = -1


def subtuple(obj: object,
typename: str,
@@ -748,10 +746,8 @@ def __init__(
raise ValueError(
"Speculative decoding is not supported with "
"contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
- # For both multi-step scheduling and delayed sampling
+ # For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
- # For delayed sampling
- self.cached_step_inputs: List[ModelInputForHPUWithSamplingMetadata] = []

def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -888,7 +884,7 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
for seq_group_metadata in seq_group_metadata_list)
temperature = 0.0 if has_greedy_samples else 1.0
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
- -1, 0, is_prompt, temperature=temperature)
+ 0, 0, is_prompt, temperature=temperature)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
return seq_group_metadata_list, real_batch_size, batch_size_padded
@@ -2089,13 +2085,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")
free_mem = HabanaMemoryProfiler.current_free_device_memory()
- graph_free_mem = free_mem
+ graph_free_mem = free_mem - self.mem_margin
graph_free_mem = align_workers(graph_free_mem,
torch.distributed.ReduceOp.MIN)
prompt_graph_mem_ratio = float(
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
- prompt_available_memory = graph_free_mem
- decode_available_memory = graph_free_mem
+ prompt_available_memory = (prompt_graph_mem_ratio *
+                            graph_free_mem)
+ decode_available_memory = (graph_free_mem -
+                            prompt_available_memory)
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
@@ -2399,20 +2397,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

return lora_mask, lora_logits_mask

- def _get_seq_ids(self, model_input):
-     return ([sg.seq_ids[0]
-              for sg in model_input.sampling_metadata.seq_groups])
-
- def _pad_to_max_num_seqs(self, tensor, value):
-     padding_needed = self.max_num_seqs - tensor.size(0)
-     if padding_needed:
-         padding = torch.full((padding_needed, *tensor.shape[1:]),
-                              value,
-                              device=tensor.device,
-                              dtype=tensor.dtype)
-         tensor = torch.cat([tensor, padding])
-     return tensor

@torch.inference_mode()
def execute_model(
self,
@@ -2424,26 +2408,6 @@
previous_hidden_states: Optional[torch.Tensor] = None,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
- use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
- assert not (use_delayed_sampling and num_steps != 1), \
-     'Delayed sampling is not compatible with MSS!'
- if use_delayed_sampling and not model_input.is_prompt:
-     num_cached = len(self.cached_step_outputs)
-     assert num_cached > 0
-     cur_seq_ids = self._get_seq_ids(model_input)
-     cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0}
-     htorch.core.mark_step()
-     for i in range(num_cached):
-         prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
-         target_indices = [cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids]
-         padding = self.cached_step_outputs[i].size(0) - len(target_indices)
-         target_indices.extend([-1] * padding)
-         target_indices = torch.tensor(target_indices,
-                                       device=model_input.input_tokens.device,
-                                       dtype=model_input.input_tokens.dtype)
-         model_input.input_tokens.index_copy_(0, target_indices, self.cached_step_outputs[i])
-         htorch.core.mark_step()

if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
@@ -2510,7 +2474,7 @@ def execute_model(
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
- if num_steps > 1 or use_delayed_sampling:
+ if num_steps > 1:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
@@ -2575,9 +2539,9 @@ def try_revert_dummy_output_tokens():
if not self.is_driver_worker:
continue

- if use_delayed_sampling:
-     fake_output = self._delayed_sampler_outputs(model_input)
-
+ if model_input.async_callback is not None:
+     model_input.async_callback()
+ # Sample the next token.
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
@@ -2589,16 +2553,9 @@ def try_revert_dummy_output_tokens():
)
if num_steps > 1:
output = output.sampled_token_ids
-     self.cached_step_outputs.append(output)
- if use_delayed_sampling:
-     self._patch_prev_output()
-     output = self._pad_to_max_num_seqs(
-         output.sampled_token_ids, DUMMY_TOKEN_ID)
-     self.cached_step_outputs.append(output)
-     self.cached_step_inputs.append(model_input)
+     self.cached_step_outputs.append(
+         output.detach().clone())
htorch.core.mark_step()
- if model_input.async_callback is not None:
-     model_input.async_callback()
if i < num_steps - 1:
if i == 0:
if model_input.async_callback is not None:
@@ -2692,21 +2649,12 @@ def try_revert_dummy_output_tokens():
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
- if use_delayed_sampling:
-     return [fake_output]

return [output] if self.is_driver_worker else []
else:
return []

return output if type(output) is list else [output]

- def _delayed_sampler_outputs(self, model_input):
-     next_token_ids = [[DUMMY_TOKEN_ID]] * len(model_input.sampling_metadata.seq_groups)
-     sampler_output = self._make_decode_output(
-         next_token_ids, model_input.sampling_metadata.seq_groups)
-     return sampler_output

def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
@@ -2756,24 +2704,3 @@ def _make_decode_output(
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return SamplerOutput(sampler_outputs)

- def _patch_prev_output(self):
-     assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
-         f'Inputs and outputs are out of sync! {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'
-     if len(self.cached_step_inputs) == 0:
-         return
-     model_input = self.cached_step_inputs.pop(0)
-     delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(-1).tolist()
-     ctx = model_input.async_callback.keywords["ctx"]
-     assert len(ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
-     output_data = ctx.output_queue[0]
-     assert len(output_data.outputs) == 1
-     for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
-         fake_out.samples[0].output_token = real_out
-     for sg, real_out in zip(output_data.seq_group_metadata_list, delayed_output):
-         assert len(sg.seq_data) == 1
-         seq_data = list(sg.seq_data.values())[0]
-         # This is a hack. Assigning output_token_ids triggers
-         # a cache recomputation and we only need to update the last token
-         seq_data.output_token_ids_array[-1] = real_out
-         seq_data._cached_all_token_ids[-1] = real_out
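The `warmup_model` hunk above restores the memory budget split for HPU graph capture. A rough standalone sketch of that arithmetic (names and example numbers are illustrative; the `align_workers` cross-worker min-reduction is omitted):

```python
import os

def split_graph_memory(free_mem: int, mem_margin: int):
    """Restored budget: subtract the memory margin, then split what is left
    between prompt and decode graphs by VLLM_GRAPH_PROMPT_RATIO (default 0.3)."""
    graph_free_mem = free_mem - mem_margin
    prompt_graph_mem_ratio = float(
        os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
    prompt_available_memory = int(prompt_graph_mem_ratio * graph_free_mem)
    decode_available_memory = graph_free_mem - prompt_available_memory
    return prompt_available_memory, decode_available_memory

# e.g. 40 GiB free with a 2 GiB margin -> ~11.4 GiB for prompt graphs
# and ~26.6 GiB for decode graphs.
prompt_mem, decode_mem = split_graph_memory(40 << 30, 2 << 30)
```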
4 changes: 2 additions & 2 deletions vllm/worker/hpu_worker.py
@@ -314,7 +314,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
logger.info(msg)
# At this point we should've allocated the maximum workspace for all
# recipes we will use the extra memory for graphs/blocks
- free_hpu_memory = torch.hpu.mem_get_info()[0] * 2.0
+ free_hpu_memory = torch.hpu.mem_get_info()[0]

cache_block_size = self.get_cache_block_size_bytes()
graph_reserved_mem = (float(
@@ -336,7 +336,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
f"{format_bytes(cache_size_bytes)} reserved for KV cache")
logger.info(msg)
- num_hpu_blocks = int(os.environ.get('VLLM_NUM_HPU_BLOCKS', '3072'))
+ num_hpu_blocks = int(cache_size_bytes // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0)
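Finally, the `hpu_worker.py` hunks drop the 2x free-memory overcommit and the fixed `VLLM_NUM_HPU_BLOCKS=3072` override, going back to deriving the KV-cache block count from measured free memory. A hedged sketch of that calculation (the full budget formula sits partly outside this hunk, so the utilization and graph-reservation factors below are assumptions):

```python
def num_kv_cache_blocks(free_hpu_memory: int,
                        cache_block_size: int,
                        gpu_memory_utilization: float = 0.9,
                        graph_reserved_mem: float = 0.1) -> int:
    """Derive the number of HPU KV-cache blocks from real free memory:
    scale by the utilization target, set aside the graph reservation,
    and divide by the per-block size (as restored in the diff)."""
    usable = free_hpu_memory * gpu_memory_utilization * (1.0 - graph_reserved_mem)
    cache_size_bytes = int(usable)
    return max(cache_size_bytes // cache_block_size, 0)

# e.g. 64 GiB free and 2 MiB blocks -> roughly 26.5k blocks with the
# assumed defaults above.
blocks = num_kv_cache_blocks(64 << 30, 2 << 20)
```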