sequence level processing -> batch level processing (#62)
### sequence level processing -> batch level processing
In this PR, the code that prepares the input tensors for the AIU is
completely rewritten, based on the constraint that the current decoding
has to finish on the AIU before another prefill can be scheduled. A minimal
sketch of this batch-level constraint follows the change list below.

Changes:
*
[rewriting](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/cea122c220b18e3de3dce95faa5e03fe3efe0835)
`sendnn_model_runner.py`, `sendnn_worker.py` and `sendnn.py` based on
the above constraint.
*
[removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/6869231d83734d3c03ffd15bc6754c1857d063cc)
the class variable `self._padded_batch_size`, since a different solution
is now implemented
*
[removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/ff9ebf6923fd9ac6c99e64dfffc7763f6c194399)
the unused `input_block_ids`, since the AIU does not support paged
attention yet.
*
[removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/a6d63899bf3d9fae59edde414b8bd2a3c56bc8c7)
some unused function arguments in model loading
*
[removing](https://github.ibm.com/ai-foundation/vllm/pull/62/commits/4527300ee9be4dd1fb76007fb6e0862b97d51676)
the unused function `_get_model_architecture()` and the global variable
`_SENDNN_SUPPORTED_MODELS`

The code has been tested in client/server mode with `llama 194m` and
`granite 3b` on `AIU` and `CPU`.
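
As a rough illustration of the batch-level constraint, here is a minimal,
hypothetical sketch (the class, attribute names, and batch size below are made
up and not taken from this PR): new prefills are only admitted once every
sequence in the current decode batch has finished.

```python
# Minimal sketch of batch-level processing; illustrative only.
class BatchLevelScheduler:

    def __init__(self, max_batch_size: int = 4):
        self.max_batch_size = max_batch_size
        self.waiting: list[dict] = []  # requests not yet prefilled
        self.running: list[dict] = []  # current decode batch

    def add_request(self, request: dict) -> None:
        self.waiting.append(request)

    def step(self) -> tuple[str, list[dict]]:
        # Drop sequences that finished decoding in the previous step.
        self.running = [r for r in self.running if not r["finished"]]
        # Only start new prefills once the decode batch is fully drained.
        if not self.running and self.waiting:
            self.running = self.waiting[:self.max_batch_size]
            self.waiting = self.waiting[self.max_batch_size:]
            return "prefill", self.running
        return "decode", self.running
```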
yannicks1 authored and GitHub Enterprise committed Nov 6, 2024
1 parent fc794f8 commit 233b4a6
Showing 4 changed files with 104 additions and 334 deletions.
7 changes: 5 additions & 2 deletions vllm/core/scheduler.py
@@ -895,10 +895,13 @@ def _schedule_prefills(
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
                     " and exceeds limit of %d", num_new_tokens, prompt_limit)
-            if cond2 := self.scheduler_config.spyre_scheduling_enabled and num_new_tokens > max(self.scheduler_config.spyre_warmup_shapes.keys()):
+            if cond2 := self.scheduler_config.spyre_scheduling_enabled and num_new_tokens > max(
+                    self.scheduler_config.spyre_warmup_shapes.keys()):
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
-                    " and exceeds maximum padding length of %d", num_new_tokens, max(self.scheduler_config.spyre_warmup_shapes.keys()))
+                    " and exceeds maximum padding length of %d",
+                    num_new_tokens,
+                    max(self.scheduler_config.spyre_warmup_shapes.keys()))
             if cond1 or cond2:
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
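
A hedged illustration of the check added above: `spyre_warmup_shapes` is
assumed here to be a dict keyed by padded prompt length (its values do not
matter for this check), so the largest key is the longest prompt the
warmed-up graphs can accept.

```python
# Illustration only; the real structure of spyre_warmup_shapes may differ.
spyre_warmup_shapes = {64: None, 128: None, 256: None}

num_new_tokens = 300
max_padding_length = max(spyre_warmup_shapes.keys())  # 256
if num_new_tokens > max_padding_length:
    print(f"Input prompt ({num_new_tokens} tokens) is too long and exceeds "
          f"maximum padding length of {max_padding_length}")
```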
32 changes: 5 additions & 27 deletions vllm/model_executor/model_loader/sendnn.py
@@ -1,6 +1,6 @@
 """Utilities for selecting and loading SENDNN models."""
 import sys
-from typing import Dict, Optional, Tuple, List
+from typing import Optional, List
 
 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@
 from fms.models import get_model
 
 import vllm.envs as envs
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, DeviceConfig
+from vllm.config import ModelConfig, ParallelConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -31,14 +31,6 @@
     print("WARNING: Disabled: dynamo_tracer")
     pass
 
-# Models supported by SENDNN.
-_SENDNN_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
-    # "LlamaForCausalLM": ("transformers_neuronx.llama.model",
-    #                      "LlamaForSampling", "LlamaForCausalLM"),
-    # "MistralForCausalLM": ("transformers_neuronx.mistral.model",
-    #                        "MistralForSampling", "MistralForCausalLM")
-}
-
 BACKEND_LIST = ['sendnn_decoder', 'inductor']
 
 # used as baseline the following code for llama 7b
@@ -71,7 +63,6 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         masks: torch.Tensor,
-        input_block_ids: torch.Tensor,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> torch.Tensor:
 
@@ -97,8 +88,8 @@ def forward(
         logits, past_key_value_states = output
         self.past_key_value_states = past_key_value_states
 
-        # removing batch padding again to compute logits
-        batch_size = input_block_ids.shape[0]
+        # removing batch padding sequences to compute logits
+        batch_size = input_ids.shape[0]
 
         logits = logits[:batch_size - self.num_padded_sequences]
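
The hunk above derives the padded batch size from `input_ids` instead of the
removed `input_block_ids` before cutting the padding sequences off the
logits. A hedged, self-contained sketch of this slicing (batch size, sequence
length, vocabulary size, and padding count are made up):

```python
import torch

# Two real sequences padded up to a warmup batch size of 4 (values made up).
num_padded_sequences = 2
input_ids = torch.zeros(4, 16, dtype=torch.long)  # padded batch of 4 sequences
logits = torch.randn(4, 32000)                    # one row of logits per sequence

# Same slicing as in the diff: drop the rows produced by padding sequences.
batch_size = input_ids.shape[0]                   # 4, including padding
logits = logits[:batch_size - num_padded_sequences]
print(logits.shape)                               # torch.Size([2, 32000])
```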

@@ -201,21 +192,8 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
             backend=envs.VLLM_SPYRE_DYNAMO_BACKEND)
 
 
-def _get_model_architecture(config: PretrainedConfig) -> str:
-    architectures = getattr(config, "architectures", [])
-    for arch in architectures:
-        if arch in _SENDNN_SUPPORTED_MODELS:
-            return arch
-    raise ValueError(
-        f"Model architectures {architectures} are not supported on SENDNN "
-        f"for now. Supported architectures: "
-        f"{list(_SENDNN_SUPPORTED_MODELS.keys())}")
-
-
 def get_sendnn_model(model_config: ModelConfig,
-                     parallel_config: ParallelConfig,
-                     scheduler_config: SchedulerConfig,
-                     device_config: DeviceConfig, max_prompt_length,
+                     parallel_config: ParallelConfig, max_prompt_length,
                      max_decode_length) -> nn.Module:
 
     # Create a model instance.
(The diffs for the remaining changed files, `sendnn_model_runner.py` and `sendnn_worker.py`, did not load and are not shown.)
