From 233b4a652492e3957ee9919eeede06186e9834b4 Mon Sep 17 00:00:00 2001
From: Yannick Schnider
Date: Wed, 6 Nov 2024 18:21:40 +0100
Subject: [PATCH] sequence level processing -> batch level processing (#62)

### sequence level processing -> batch level processing

In this PR the code for preparing the input tensors for the AIU is completely
rewritten, based on the assumption that we have to finish the current decoding
on AIU before doing another prefill.

Changes:
* [rewriting](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/cea122c220b18e3de3dce95faa5e03fe3efe0835) `sendnn_model_runner.py`, `sendnn_worker.py` and `sendnn.py` based on the above constraint.
* [removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/6869231d83734d3c03ffd15bc6754c1857d063cc) the class variable `self._padded_batch_size`, since a different solution is now implemented.
* [removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/ff9ebf6923fd9ac6c99e64dfffc7763f6c194399) the unused `input_block_ids`, since AIU does not support paged attention yet.
* [removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/a6d63899bf3d9fae59edde414b8bd2a3c56bc8c7) some unused function arguments in model loading.
* [removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/4527300ee9be4dd1fb76007fb6e0862b97d51676) the unused function `_get_model_architecture()` and the global variable `_SENDNN_SUPPORTED_MODELS`.

The code has been tested in client/server mode with the `llama 194m` and `granite 3b` models on `AIU` and `CPU`.
---
 vllm/core/scheduler.py                     |   7 +-
 vllm/model_executor/model_loader/sendnn.py |  32 +-
 vllm/worker/sendnn_model_runner.py         | 335 +++++----------
 vllm/worker/sendnn_worker.py               |  64 +---
 4 files changed, 104 insertions(+), 334 deletions(-)

diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index bfa1bd9ca..be143fb2e 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -895,10 +895,13 @@ def _schedule_prefills(
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
                     " and exceeds limit of %d", num_new_tokens, prompt_limit)
-            if cond2 := self.scheduler_config.spyre_scheduling_enabled and num_new_tokens > max(self.scheduler_config.spyre_warmup_shapes.keys()):
+            if cond2 := self.scheduler_config.spyre_scheduling_enabled and num_new_tokens > max(
+                    self.scheduler_config.spyre_warmup_shapes.keys()):
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
-                    " and exceeds maximum padding length of %d", num_new_tokens, max(self.scheduler_config.spyre_warmup_shapes.keys()))
+                    " and exceeds maximum padding length of %d",
+                    num_new_tokens,
+                    max(self.scheduler_config.spyre_warmup_shapes.keys()))
             if cond1 or cond2:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
diff --git a/vllm/model_executor/model_loader/sendnn.py b/vllm/model_executor/model_loader/sendnn.py
index 7756977e7..36c9db892 100644
--- a/vllm/model_executor/model_loader/sendnn.py
+++ b/vllm/model_executor/model_loader/sendnn.py
@@ -1,6 +1,6 @@
 """Utilities for selecting and loading SENDNN models."""
 import sys
-from typing import Dict, Optional, Tuple, List
+from typing import Optional, List
 
 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from fms.models import get_model
 
 import vllm.envs as envs
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, DeviceConfig
+from vllm.config import ModelConfig, ParallelConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -31,14 +31,6 @@
     print("WARNING: Disabled: dynamo_tracer")
     pass
 
-# Models supported by SENDNN.
-_SENDNN_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
-    # "LlamaForCausalLM": ("transformers_neuronx.llama.model",
-    #                      "LlamaForSampling", "LlamaForCausalLM"),
-    # "MistralForCausalLM": ("transformers_neuronx.mistral.model",
-    #                        "MistralForSampling", "MistralForCausalLM")
-}
-
 BACKEND_LIST = ['sendnn_decoder', 'inductor']
 
 # used as baseline the following code for llama 7b
@@ -71,7 +63,6 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         masks: torch.Tensor,
-        input_block_ids: torch.Tensor,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> torch.Tensor:
 
@@ -97,8 +88,8 @@ def forward(
         logits, past_key_value_states = output
         self.past_key_value_states = past_key_value_states
 
-        # removing batch padding again to compute logits
-        batch_size = input_block_ids.shape[0]
+        # removing batch padding sequences to compute logits
+        batch_size = input_ids.shape[0]
 
         logits = logits[:batch_size - self.num_padded_sequences]
 
@@ -201,21 +192,8 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
         backend=envs.VLLM_SPYRE_DYNAMO_BACKEND)
 
 
-def _get_model_architecture(config: PretrainedConfig) -> str:
-    architectures = getattr(config, "architectures", [])
-    for arch in architectures:
-        if arch in _SENDNN_SUPPORTED_MODELS:
-            return arch
-    raise ValueError(
-        f"Model architectures {architectures} are not supported on SENDNN "
-        f"for now. Supported architectures: "
-        f"{list(_SENDNN_SUPPORTED_MODELS.keys())}")
-
-
 def get_sendnn_model(model_config: ModelConfig,
-                     parallel_config: ParallelConfig,
-                     scheduler_config: SchedulerConfig,
-                     device_config: DeviceConfig, max_prompt_length,
+                     parallel_config: ParallelConfig, max_prompt_length,
                      max_decode_length) -> nn.Module:
 
     # Create a model instance.
diff --git a/vllm/worker/sendnn_model_runner.py b/vllm/worker/sendnn_model_runner.py
index d5cead7b2..2b1a30917 100644
--- a/vllm/worker/sendnn_model_runner.py
+++ b/vllm/worker/sendnn_model_runner.py
@@ -11,7 +11,7 @@
 from vllm.model_executor.model_loader.sendnn import get_sendnn_model
 from vllm.sequence import SequenceGroupMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
 
@@ -43,13 +43,10 @@ def __init__(
         self._prompt_lens = [64]
         self._num_decode_tokens = [20]
         self._batch_sizes = [1]
-        # will be set accordingly in prefill phase
-        self._padded_batch_size = self._batch_sizes[0]
-        # key: request_id, value: position_ids of sequence
-        self._position_ids = dict()
-        # key: request_id, value: attention mask of sequence
-        self._mask = dict()
-
+        # position_ids of all the sequences in current batch
+        self._position_ids: torch.Tensor = None
+        # attention masks of all the sequences in current batch
+        self._mask: torch.Tensor = None
         # Lazy initialization: after load_model.
         self.model: nn.Module
 
@@ -63,36 +60,25 @@ def load_model(self,
             self._num_decode_tokens = num_decode_tokens
         if batch_sizes:
             self._batch_sizes = batch_sizes
-        max_pad_lenght = max(self._prompt_lens)
+        max_pad_length = max(self._prompt_lens)
         max_decode_length = max(self._num_decode_tokens)
         self.model = get_sendnn_model(self.model_config,
                                       parallel_config=self.parallel_config,
-                                      scheduler_config=self.scheduler_config,
-                                      device_config=self.device_config,
-                                      max_prompt_length=max_pad_lenght,
+                                      max_prompt_length=max_pad_length,
                                       max_decode_length=max_decode_length)
 
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
         assert len(seq_group_metadata_list) > 0
 
-        input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
-        input_masks: List[torch.Tensor] = []
-        input_block_ids: List[int] = []
-        seq_lens: List[int] = []
+        input_token_list: List[torch.Tensor] = []
 
         # find max prompt length among sequences for padding
-        max_prompt_len_batch = 0
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data = seq_group_metadata.seq_data[list(
-                seq_group_metadata.seq_data.keys())[0]]
-            # retrieve initial (unpadded) tokens
-            prompt_tokens = seq_data.get_token_ids()
-            max_prompt_len_batch = max(max_prompt_len_batch,
-                                       len(prompt_tokens))
+        max_prompt_len_batch = max(
+            len(seq_group_metadata.seq_data[next(
+                iter(seq_group_metadata.seq_data))].get_token_ids())
+            for seq_group_metadata in seq_group_metadata_list)
 
         # find next bigger compiled padding length
         large_enough_pads = [
@@ -102,11 +88,17 @@ def _prepare_prompt(
             # TODO: should be removed in the future: scheduler should intercept and fail too large requests.
             min_pad_length_batch = max(self._prompt_lens)
             print(
                 f'[SENDNNModelRunner] ERROR: Request is too large ({len(max_prompt_len_batch)} tokens), truncating to {min_pad_length_batch} tokens.'
             )
         else:
             min_pad_length_batch = min(large_enough_pads)
 
+        # calculating the max number of decode tokens for the required padding
+        max_num_decode_tokens = max([
+            i for (i, v) in zip(self._num_decode_tokens, self._prompt_lens)
+            if v == min_pad_length_batch
+        ])
+
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -114,66 +106,27 @@ def _prepare_prompt(
             seq_id = seq_ids[0]
 
             seq_data = seq_group_metadata.seq_data[seq_id]
-            request_id = seq_group_metadata.request_id
             # retrieve initial (unpadded) tokens
             prompt_tokens = seq_data.get_token_ids()
+            # truncating prompt for warmed up prompt length
+            prompt_tokens = prompt_tokens[:min_pad_length_batch]
 
             # verifying that the number of requested output tokens is not greater the (prompt-specific) number of output tokens we compiled in warmup
-            max_num_decode_tokens = max([
-                i for (i, v) in zip(self._num_decode_tokens, self._prompt_lens)
-                if v == min_pad_length_batch
-            ])
-
             if seq_group_metadata.sampling_params.max_tokens > max_num_decode_tokens:
+                # TODO: raise error/fail request: scheduler should intercept and fail too many requested output tokens.
                 print((
                     f"[SENDNNModelRunner] ERROR: Requested number of output tokens ({seq_group_metadata.sampling_params.max_tokens}) bigger than the "
                     f"number of output tokens the model was compiled for during warmup ({max_num_decode_tokens}).\n-> Capping at {max_num_decode_tokens} output tokens!"
                 ))
                 seq_group_metadata.sampling_params.max_tokens = max_num_decode_tokens
 
-            # should be multiples of 64 for AIU
-            # TODO: raise error/fail request: scheduler should intercept and fail too large requests.
-            prompt_tokens = prompt_tokens[:min_pad_length_batch]
-
-            # do padding, get new token_ids, position_ids and masks
-            prompt_token_tensor = torch.tensor(prompt_tokens,
-                                               dtype=torch.long,
-                                               device=torch.device("cpu"))
-
-            if min_pad_length_batch > len(prompt_tokens):
-                print(
-                    f'[SENDNNModelRunner] INFO: Padding request of length {len(prompt_tokens)} tokens to {min_pad_length_batch} tokens.'
-                )
-
-            prompt_token_padded_tensor, padding_kwargs = self.pad_input_ids(
-                [prompt_token_tensor], min_pad_length=min_pad_length_batch)
-
-            prompt_token_padded = prompt_token_padded_tensor.tolist()[0]
-
-            # set padded position ids for request_id
-            self._position_ids[request_id] = padding_kwargs['position_ids'][
-                0].tolist()
-            # set padding attention mask for request_id
-            self._mask[request_id] = padding_kwargs['mask'][0]
-
-            input_tokens.append(prompt_token_padded)
-
-            seq_len = len(prompt_token_padded)
-            seq_lens.append(seq_len)
-
-            input_positions.append(self._position_ids[request_id])
-
-            input_masks.append(self._mask[request_id])
-
-            assert seq_group_metadata.block_tables is not None
-            block_table = seq_group_metadata.block_tables[seq_id]
-            assert len(block_table) == 1
-            input_block_ids.append(block_table[0])
-
-        max_seq_len = max(seq_lens)
-        assert max_seq_len > 0
+            input_token_list.append(
+                torch.tensor(prompt_tokens,
+                             dtype=torch.long,
+                             device=torch.device("cpu")))
 
         actual_batch_size = len(seq_group_metadata_list)
+        # find next bigger compiled batch size to pad to given the prompt size matches as well
         large_enough_batch_sizes = [
             b for (b, p) in zip(self._batch_sizes, self._prompt_lens)
@@ -187,82 +140,38 @@ def _prepare_prompt(
             print(
                 f'[SENDNNModelRunner] ERROR: Batch is too large ({actual_batch_size} sequences) for compiled batchsizes:{compiled_batch_sizes}'
             )
-            # TODO: handling? current implementation will crash the server on AIU...
+            # TODO: how do we handle this? Current implementation will crash the server on AIU...
             padded_batch_size = actual_batch_size
         else:
             padded_batch_size = min(large_enough_batch_sizes)
 
-        # store batch size for decode phase
-        self._padded_batch_size = padded_batch_size
-        self.model.num_padded_sequences = 0
+        # set number of added padding sequences used for computing logits
+        self.model.num_padded_sequences = padded_batch_size - len(
+            input_token_list)
 
-        # padding to batch size
-        if padded_batch_size > actual_batch_size:
-            print(
-                f'[SENDNNModelRunner] INFO: Batch of size {actual_batch_size} is padded to compiled batchsize {padded_batch_size}'
-            )
-            # preparing batch padding token_ids, position_ids, masks and block_ids
-            num_batch_pads = padded_batch_size - actual_batch_size
-
-            # idea: give it a single token, rest will be padded: less computations?
-            input_tokens_pad = torch.tensor([0],
-                                            dtype=torch.long,
-                                            device=torch.device("cpu"))
-
-            input_tokens_pad_tensor, padding_kwargs_pad = self.pad_input_ids(
-                [input_tokens_pad], min_pad_length=min_pad_length_batch)
-
-            input_tokens_pad = input_tokens_pad_tensor.tolist()[0]
-
-            # set padded position ids for request_id ='padding_request_id'
-            self._position_ids['padding_request_id'] = padding_kwargs_pad[
-                'position_ids'][0].tolist()
-
-            # set padding attention mask for request_id = 'padding_request_id'
-            self._mask['padding_request_id'] = padding_kwargs_pad['mask'][0]
-
-            # append needed batch dimensions
-            for i in range(num_batch_pads):
-                # token ids
-                input_tokens.append(input_tokens_pad)
-                seq_lens.append(max_seq_len)
-                # position ids
-                input_positions.append(
-                    self._position_ids['padding_request_id'])
-                # masks
-                input_masks.append(self._mask['padding_request_id'])
-                # block ids: no usage on AIU yet
-                input_block_ids.append(0)
-                # increase padded batches counter
-                self.model.num_padded_sequences += 1
-
-        input_tokens = make_tensor_with_pad(input_tokens,
-                                            pad=0,
-                                            dtype=torch.long,
-                                            max_len=max_seq_len,
-                                            device=self.device)
-        input_positions = make_tensor_with_pad(input_positions,
-                                               pad=0,
-                                               dtype=torch.long,
-                                               max_len=max_seq_len,
-                                               device=self.device)
-        input_masks = torch.stack(input_masks, dim=0)
-        input_block_ids = torch.tensor(input_block_ids,
-                                       dtype=torch.long,
-                                       device=self.device)
+        # padding to compiled batch size
+        while len(input_token_list) < padded_batch_size:
+            input_token_list.append(
+                torch.zeros(min_pad_length_batch,
+                            dtype=torch.long,
+                            device=torch.device("cpu")))
+
+        # get position ids and attention mask
+        input_tokens, self._position_ids, self._mask = self.pad_input_ids(
+            input_token_list, min_pad_length=min_pad_length_batch)
 
-        return input_tokens, input_positions, input_masks, input_block_ids, seq_lens
+        seq_lens = [
+            input_tokens.shape[1] for i in range(input_tokens.shape[0])
+        ]
+
+        return input_tokens, self._position_ids, self._mask, seq_lens
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
-        input_masks: List[torch.Tensor] = []
-        input_block_ids: List[int] = []
-        context_lens: List[int] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
@@ -271,107 +180,37 @@ def _prepare_decode(
             seq_id = seq_ids[0]
 
             seq_data = seq_group_metadata.seq_data[seq_id]
-            request_id = seq_group_metadata.request_id
             generation_token = seq_data.get_last_token_id()
             input_tokens.append([generation_token])
 
-            seq_len = seq_data.get_len()
-
-            position_id = self._position_ids[request_id][-1] + 1
-            # append new position to sequence
-            self._position_ids[request_id] = self._position_ids[request_id] + [
-                position_id
-            ]
-            input_positions.append([position_id])
-
-            self._update_mask(request_id)
-            input_masks.append(self._mask[request_id])
-
-            context_lens.append(seq_len)
-
-            assert seq_group_metadata.block_tables is not None
-            block_table = seq_group_metadata.block_tables[seq_id]
-            assert len(block_table) == 1
-            input_block_ids.append(block_table[0])
+        # padding to compiled batch size
+        actual_batch_size = len(seq_group_metadata_list)
+        padded_batch_size = self._position_ids.shape[0]
+        while actual_batch_size < padded_batch_size:
+            input_tokens.append([0])
+            actual_batch_size += 1
 
-            # delete attention masks and positions ids in last decoding step to free memory
-            # TODO ysc: add condition when reaching eos token.
-            if seq_data.get_output_len(
-            ) == seq_group_metadata.sampling_params.max_tokens - 1:
-                # delete attention mask and position ids for corresponding request_id
-                del self._mask[request_id]
-                del self._position_ids[request_id]
+        # update position ids and attention mask
+        self._update_position_ids()
+        self._update_mask()
 
-        actual_batch_size = len(seq_group_metadata_list)
-        # getting batch size we padded to in prefill stage
-        padded_batch_size = self._padded_batch_size
-
-        # padding to batch size
-        if padded_batch_size > actual_batch_size:
-            # preparing batch padding token_ids, position_ids, masks and block_ids
-            num_batch_pads = padded_batch_size - actual_batch_size
-
-            # token_ids and position_ids
-            token_id_pad = [0]
-            position_id_pad = [
-                self._position_ids['padding_request_id'][-1] + 1
-            ]
-            # update position ids and mask
-            self._position_ids['padding_request_id'] = self._position_ids[
-                'padding_request_id'] + position_id_pad
-            self._update_mask('padding_request_id')
-
-            # append needed batch dimensions
-            for i in range(num_batch_pads):
-                # token ids
-                input_tokens.append(token_id_pad)
-                # position ids
-                input_positions.append(position_id_pad)
-                # masks
-                input_masks.append(self._mask['padding_request_id'])
-                # padding sequence has context length 0
-                context_lens.append(0)
-                # block ids: no usage on AIU yet
-                input_block_ids.append(0)
-
-        # delete attention masks and position ids of batch padding in last decoding step to free memory
-        if len(self._mask) == 1 and len(self._position_ids) == 1:
-            # if batch padding was applied and there is only one remaining entry -> end of decoding -> delete padding entry
-            del self._mask['padding_request_id']
-            del self._position_ids['padding_request_id']
-
-        input_tokens = make_tensor_with_pad(input_tokens,
-                                            pad=0,
-                                            dtype=torch.long,
-                                            max_len=1,
-                                            device=self.device)
-        input_positions = make_tensor_with_pad(input_positions,
-                                               pad=0,
-                                               dtype=torch.long,
-                                               max_len=1,
-                                               device=self.device)
-        input_masks = torch.stack(input_masks, dim=0)
-        # why is this here, it has no effect?
-        context_lens = torch.tensor(context_lens,
-                                    dtype=torch.int,
+        input_tokens = torch.tensor(input_tokens,
+                                    dtype=torch.long,
                                     device=self.device)
-        input_block_ids = torch.tensor(input_block_ids,
-                                       dtype=torch.long,
-                                       device=self.device)
 
-        return input_tokens, input_positions, input_masks, input_block_ids
+        return input_tokens, self._position_ids, self._mask
+
+    def _update_position_ids(self) -> None:
+        """Updating the position ids of all sequences in a batch. Will be called in decoding phase"""
 
-    def _update_mask(self, request_id) -> None:
-        """Updating/extending the attention masks of a sequence in a SequenceGroup. Will be called in decoding phase"""
+        self._position_ids = self._position_ids[:, -1] + 1
+        self._position_ids = self._position_ids.unsqueeze(-1)
 
-        assert self._mask[request_id] is not None
-        masks = self._mask[request_id]
+    def _update_mask(self) -> None:
+        """Updating/extending the attention masks of all sequences in a batch. Will be called in decoding phase"""
 
-        # expand batch dimension (batch size 1) during inference to use the same function for inference and warmup
-        is_decoding = False
-        if len(masks.shape) == 2:
-            masks = masks.unsqueeze(0)
-            is_decoding = True
+        assert self._mask is not None
+        masks = self._mask
 
         masks_new = []
         for mask in masks:
@@ -389,30 +228,23 @@ def _update_mask(self, request_id) -> None:
             )
             masks_new.append(mask_new)
 
-        masks_new_stacked = torch.stack(masks_new, dim=0)
-
-        # collaps batch dimension again for decoding phase (scheduler handles batch dimensions there)
-        if is_decoding:
-            masks_new_stacked = masks_new_stacked.squeeze(0)
-
-        self._mask[request_id] = masks_new_stacked
+        self._mask = torch.stack(masks_new, dim=0)
 
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         finished_requests_ids: Optional[List[str]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               SamplingMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
         is_prompt = seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, input_masks, input_block_ids,
+            (input_tokens, input_positions, input_masks,
              seq_lens) = self._prepare_prompt(seq_group_metadata_list)
         else:
-            (input_tokens, input_positions, input_masks,
-             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
+            (input_tokens, input_positions,
+             input_masks) = self._prepare_decode(seq_group_metadata_list)
             seq_lens = []
 
         # Clean up generators from completed requests
@@ -430,16 +262,14 @@ def prepare_input_tensors(
                                                      self.device,
                                                      self.pin_memory,
                                                      self.generators)
-        return (input_tokens, input_positions, input_masks, input_block_ids,
-                sampling_metadata)
+        return (input_tokens, input_positions, input_masks, sampling_metadata)
 
-    #@torch.inference_mode()
     def execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         finished_requests_ids: Optional[List[str]] = None,
    ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, input_masks, input_block_ids,
+        (input_tokens, input_positions, input_masks,
         sampling_metadata) = self.prepare_input_tensors(
             seq_group_metadata_list, finished_requests_ids)
         t0 = time.time()
@@ -447,7 +277,6 @@ def execute_model(
             input_ids=input_tokens,
             positions=input_positions,
             masks=input_masks,
-            input_block_ids=input_block_ids,
             seq_group_metadata_list=seq_group_metadata_list,
         )
 
@@ -473,16 +302,19 @@ def pad_input_ids(
         self,
         input_ids_list: List[torch.Tensor],
         min_pad_length: int = 0,
-    ) -> Tuple[torch.Tensor, MutableMapping[str, Any]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         '''left side padding implemented as in fms.utils.generation.pad_input_id'''
 
         max_len = max([min_pad_length] + [seq.size(0) for seq in input_ids_list])
-
         padded_input_ids_list = []
         mask_list = []
         position_ids_list = []
         for input_ids_i in input_ids_list:
             seq_len = input_ids_i.size(0)
+            if max_len > seq_len:
+                print(
+                    f'[SENDNNModelRunner] INFO: Padding request of length {seq_len} tokens to {max_len} tokens.'
+                )
             pads = torch.zeros(max_len - seq_len,
                                dtype=torch.long,
                                device=input_ids_i.device)
@@ -509,9 +341,6 @@ def pad_input_ids(
         mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
         mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
         mask = mask.to(self.model.dtype)
-        padding_kwargs["mask"] = mask
 
         position_ids = torch.stack(position_ids_list)
-        padding_kwargs["position_ids"] = position_ids
-
-        return input_ids, padding_kwargs
+        return input_ids, position_ids, mask
diff --git a/vllm/worker/sendnn_worker.py b/vllm/worker/sendnn_worker.py
index 534da0ef0..3eb66b1f5 100644
--- a/vllm/worker/sendnn_worker.py
+++ b/vllm/worker/sendnn_worker.py
@@ -22,7 +22,6 @@
     init_distributed_environment)
 
 from torch_sendnn import torch_sendnn
-from vllm.utils import make_tensor_with_pad
 
 
 class SENDNNWorker(LoraNotSupportedWorkerBase):
@@ -227,11 +226,6 @@ def _warmup_sendnn_fixed_size(self, prompt_len, num_decode_tokens,
                                          num_decode_tokens, batch_size,
                                          extra_kwargs)
 
-        # delete warmup request_id entries
-        del self.model_runner._position_ids['warmup_request_id']
-        del self.model_runner._mask['warmup_request_id']
-
-        # 3. done?
         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
         print("[SENDNNWorker] ... warmup finished.")
@@ -244,66 +238,32 @@ def _warmup_model_forward_pass(self, warmup_tokens_tensor,
                                    num_decode_tokens, batch_size,
                                    extra_kwargs):
         # padding warmup tokens to obtain the corresponding postition ids and mask
-        warmup_tokens_pad, warmup_padding_kwargs = self.model_runner.pad_input_ids(
+        warmup_tokens_pad, self.model_runner._position_ids, self.model_runner._mask = self.model_runner.pad_input_ids(
             warmup_tokens_tensor, min_pad_length=prompt_len)
 
-        # set padding position ids in position_ids dict for a dummy warmup request_id = 'warmup_request_id'
-        self.model_runner._position_ids[
-            'warmup_request_id'] = warmup_padding_kwargs[
-                'position_ids'].tolist()
-        # set padding attention mask in mask dict for a dummy warmup request_id = 'warmup_request_id'
-        self.model_runner._mask['warmup_request_id'] = warmup_padding_kwargs[
-            'mask']
-
-        warmup_positions_pad = make_tensor_with_pad(
-            self.model_runner._position_ids['warmup_request_id'],
-            pad=0,
-            dtype=torch.long,
-            max_len=prompt_len,
-            device=self.model_runner.device)
-
+        # prefill
         logits, past_key_value_states = self.model_runner.model.model(
             warmup_tokens_pad,
-            position_ids=warmup_positions_pad,
-            mask=self.model_runner._mask['warmup_request_id'],
+            position_ids=self.model_runner._position_ids,
+            mask=self.model_runner._mask,
             past_key_value_states=None,
             use_cache=True,
             only_last_token=True,
             **extra_kwargs)
 
+        # decoding
         for i in range(num_decode_tokens - 1):
             # sampling next input token from vocab without bos and eos tokens
-            decode_tokens_pad = valid_token_ids_tensor[torch.randint(
+            decode_tokens = valid_token_ids_tensor[torch.randint(
                 0, len(valid_token_ids_tensor), (batch_size, 1))]
 
-            # update mask
-            self.model_runner._update_mask(request_id='warmup_request_id')
-            # update position ids
-            position_ids = []
-            position_ids_input = []
-            for i in range(
-                    len(self.model_runner._position_ids['warmup_request_id'])):
-                # take last position id and increment it by 1
-                new_position_id = self.model_runner._position_ids[
-                    'warmup_request_id'][i][-1] + 1
-                position_ids_input.append([new_position_id])
-                # append new position id to existing list of position ids
-                position_ids.append(
-                    self.model_runner._position_ids['warmup_request_id'][i] +
-                    [new_position_id])
-            # store updated list position ids
-            self.model_runner._position_ids['warmup_request_id'] = position_ids
-
-            decode_positions_pad = make_tensor_with_pad(
-                position_ids_input,
-                pad=0,
-                dtype=torch.long,
-                max_len=1,
-                device=self.model_runner.device)
+            # update mask and position_ids
+            self.model_runner._update_mask()
+            self.model_runner._update_position_ids()
 
             logits, past_key_value_states = self.model_runner.model.model(
-                decode_tokens_pad,
-                position_ids=decode_positions_pad,
-                mask=self.model_runner._mask['warmup_request_id'],
+                decode_tokens,
+                position_ids=self.model_runner._position_ids,
+                mask=self.model_runner._mask,
                 past_key_value_states=past_key_value_states,
                 use_cache=True,
                 only_last_token=True,