Ysc pluggable scheduler (#78)
This PR moves the Spyre-specific scheduler class out of vLLM and into the plugin repository `vllm-spyre`. It goes along with this [PR](https://github.com/IBM/vllm-spyre/pull/4) in the `vllm-spyre` repository.

Changes:
- `vllm/config.py`: introduce a new `scheduler_cls` field on `SchedulerConfig`, accepting either a qualified class path (`str`) or a class object directly, and remove the Spyre-specific code that loaded warmup shapes and sorted them by "speed" (moved to `vllm-spyre`); see the sketch after this list
- `vllm/engine/llm_engine.py`: import the `Scheduler` class based on `vllm_config.scheduler_config.scheduler_cls`
- `vllm/core/scheduler.py`: remove all Spyre-specific code from the scheduler logic
- `vllm/envs.py`: remove all Spyre-related environment variables (moved to `vllm-spyre`)
- `vllm/config.py`: remove the `SchedulerConfig.spyre_warmup_shapes` field
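
For illustration only (none of this code ships in the PR): a minimal sketch of how a downstream plugin can now supply its own scheduler. The `vllm_spyre.core.scheduler.SpyreScheduler` path is hypothetical, and the sketch assumes the remaining `SchedulerConfig` fields have usable defaults.

```python
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler

# Option 1: select the scheduler by qualified class path; the engine
# resolves the string to a class at construction time. The module path
# below is a hypothetical plugin path.
config = SchedulerConfig(
    scheduler_cls="vllm_spyre.core.scheduler.SpyreScheduler")


# Option 2: pass the class object itself, e.g. a subclass of the default.
class CustomScheduler(Scheduler):
    """Hypothetical scheduler customizing vLLM's default behavior."""


config = SchedulerConfig(scheduler_cls=CustomScheduler)
```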

---------

Signed-off-by: Yannick Schnider <[email protected]>
yannicks1 authored Feb 20, 2025
1 parent caad376 commit 8c3c29f
Showing 4 changed files with 16 additions and 139 deletions.
39 changes: 3 additions & 36 deletions vllm/config.py
```diff
@@ -3,7 +3,6 @@
 import enum
 import hashlib
 import json
-import operator
 import sys
 import warnings
 from contextlib import contextmanager
@@ -1428,6 +1427,9 @@ class SchedulerConfig:
 
     chunked_prefill_enabled: bool = field(init=False)
 
+    # scheduler class or path
+    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -1486,41 +1488,6 @@ def __post_init__(self) -> None:
                 self.max_num_batched_tokens)
 
         self.chunked_prefill_enabled = self.enable_chunked_prefill
-        from vllm.platforms import current_platform
-        self.spyre_scheduling_enabled = current_platform.get_device_name(
-        ) == "spyre"
-        if self.spyre_scheduling_enabled:
-            # load warmup shapes and sort by "speed"
-            wup_prompt_lens = envs.VLLM_SPYRE_WARMUP_PROMPT_LENS or []
-            wup_batch_sizes = envs.VLLM_SPYRE_WARMUP_BATCH_SIZES or []
-            if len(wup_prompt_lens) != len(wup_batch_sizes):
-                raise RuntimeError(
-                    "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
-                    "VLLM_SPYRE_WARMUP_BATCH_SIZES must have equal length")
-            if self.runner_type == "pooling":
-                wup_new_tokens = [0] * len(wup_prompt_lens)
-            else:
-                wup_new_tokens = envs.VLLM_SPYRE_WARMUP_NEW_TOKENS or []
-                if len(wup_new_tokens) != len(wup_prompt_lens):
-                    raise RuntimeError(
-                        "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
-                        "VLLM_SPYRE_WARMUP_NEW_TOKENS must have equal length")
-
-            print("[SchedulerConfig] VLLM_SPYRE_WARMUP_PROMPT_LENS =",
-                  wup_prompt_lens)
-            print("[SchedulerConfig] VLLM_SPYRE_WARMUP_NEW_TOKENS =",
-                  wup_new_tokens)
-            print("[SchedulerConfig] VLLM_SPYRE_WARMUP_BATCH_SIZES =",
-                  wup_batch_sizes)
-
-            self.spyre_warmup_shapes = tuple(
-                sorted([{
-                    'prompt_length': pl,
-                    'new_tokens': nt,
-                    'batch_size': bs
-                } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
-                                        wup_batch_sizes)],
-                       key=operator.itemgetter('batch_size', 'prompt_length')))
         self._verify_args()
 
     def _verify_args(self) -> None:
```
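
For reference, the warmup-shape bookkeeping deleted above lives on in `vllm-spyre`. Below is a standalone sketch of the same sort with illustrative values; it is not code from this PR.

```python
import operator

# Illustrative warmup shapes, as previously parsed from the env vars.
wup_prompt_lens = [64, 128, 64]
wup_new_tokens = [20, 20, 20]
wup_batch_sizes = [4, 1, 2]

# Sort by batch size first, then prompt length, exactly as the removed
# code did via operator.itemgetter.
spyre_warmup_shapes = tuple(
    sorted(({
        'prompt_length': pl,
        'new_tokens': nt,
        'batch_size': bs
    } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
                            wup_batch_sizes)),
           key=operator.itemgetter('batch_size', 'prompt_length')))

print(spyre_warmup_shapes[0])  # the batch-size-1 shape sorts first
```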
77 changes: 4 additions & 73 deletions vllm/core/scheduler.py
```diff
@@ -919,10 +919,6 @@ def _schedule_prefills(
         ignored_seq_groups: List[SequenceGroup] = []
         seq_groups: List[ScheduledSequenceGroup] = []
 
-        if self.scheduler_config.spyre_scheduling_enabled:
-            applicable_spyre_warmup_shapes = list(
-                self.scheduler_config.spyre_warmup_shapes)
-
         waiting_queue = self.waiting
 
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
@@ -1003,54 +999,6 @@ def _schedule_prefills(
             ):
                 break
 
-            # check if current request can be scheduled based on the applicable
-            # spyre warmup shapes
-            if self.scheduler_config.spyre_scheduling_enabled:
-                max_tokens = 0
-                if seq_group.sampling_params is not None and\
-                        seq_group.sampling_params.max_tokens is not None:
-                    max_tokens = seq_group.sampling_params.max_tokens
-                updated_spyre_warmup_shapes = [
-                    shape for shape in applicable_spyre_warmup_shapes
-                    if num_new_tokens <= shape['prompt_length']
-                    and max_tokens <= shape['new_tokens']
-                    and len(seq_groups) < shape['batch_size']
-                ]
-                if not updated_spyre_warmup_shapes:
-                    if not seq_groups:
-                        # request was tested against all spyre warmup shapes:
-                        # request cannot be processed
-                        if (seq_group.sampling_params is not None
-                                and seq_group.sampling_params.max_tokens
-                                is not None):
-                            logger.warning(
-                                "No applicable warmup shape exists for "
-                                "combination of prompt length (%d tokens) "
-                                "and maximum number of output tokens to be "
-                                "generated (%d tokens)", num_new_tokens,
-                                seq_group.sampling_params.max_tokens)
-                        else:
-                            logger.warning(
-                                "No applicable warmup shape exists for "
-                                "combination of prompt length (%d tokens) "
-                                "and undefined maximum number of output "
-                                "tokens", num_new_tokens)
-                        for seq in waiting_seqs:
-                            seq.status = SequenceStatus.FINISHED_IGNORED
-                        ignored_seq_groups.append(seq_group)
-                        waiting_queue.popleft()
-                        continue
-                    else:
-                        # request was only tested against spyre warmup shapes
-                        # that remain after processing previous requests in
-                        # waiting queue: request will be evaluated again in
-                        # a future scheduling step
-                        leftover_waiting_sequences.appendleft(seq_group)
-                        waiting_queue.popleft()
-                        continue
-                else:
-                    applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes
-
             # Can schedule this request.
             if curr_loras is not None and lora_int_id > 0:
                 curr_loras.add(lora_int_id)
@@ -1084,15 +1032,6 @@ def _schedule_prefills(
             )
             budget.add_num_seqs(seq_group.request_id, num_new_seqs)
 
-            # Check if number of scheduled requests has reached the maximum
-            # batch size of the applicable warmup shapes
-            if self.scheduler_config.spyre_scheduling_enabled and len(
-                    seq_groups) >= max([
-                        shape['batch_size']
-                        for shape in applicable_spyre_warmup_shapes
-                    ]):
-                break
-
         # Queue requests that couldn't be scheduled.
         waiting_queue.extendleft(leftover_waiting_sequences)
         if len(seq_groups) > 0:
@@ -1130,11 +1069,8 @@ def _schedule_default(self) -> SchedulerOutputs:
         running_scheduled = SchedulerRunningOutputs.create_empty()
         swapped_in = SchedulerSwappedInOutputs.create_empty()
 
-        # Schedule new prefills only when no requests have been swapped
-        # and all previous decodes have completed.
-        if not self.swapped and (
-                not self.scheduler_config.spyre_scheduling_enabled
-                or not self.running):
+        # If any requests are swapped, prioritized swapped requests.
+        if not self.swapped:
             prefills = self._schedule_prefills(budget,
                                                curr_loras,
                                                enable_chunking=False)
@@ -1335,13 +1271,8 @@ def _can_append_slots(self, seq_group: SequenceGroup,
             # chunked-prefill are enabled together.
             assert self.scheduler_config.is_multi_step and enable_chunking
 
-        if self.scheduler_config.spyre_scheduling_enabled:
-            # heuristic below doesn't make sense when using very large
-            # blocks
-            return True
-        else:
-            return self.block_manager.can_append_slots(
-                seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
+        return self.block_manager.can_append_slots(
+            seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
 
     def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         # async_output_proc is allowed only when we have a single sequence
```
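
The heart of the removed logic is a shape filter: a request stays schedulable only while at least one warmup shape can still hold its prompt, its output budget, and one more sequence. A self-contained sketch of that check, with renamed helpers and illustrative values:

```python
from typing import Dict, List

def applicable_shapes(shapes: List[Dict[str, int]], num_new_tokens: int,
                      max_tokens: int,
                      num_scheduled: int) -> List[Dict[str, int]]:
    """Keep the warmup shapes whose prompt length, output budget, and
    batch size can still accommodate the candidate request."""
    return [
        shape for shape in shapes
        if num_new_tokens <= shape['prompt_length']
        and max_tokens <= shape['new_tokens']
        and num_scheduled < shape['batch_size']
    ]

shapes = [
    {'prompt_length': 64, 'new_tokens': 20, 'batch_size': 4},
    {'prompt_length': 128, 'new_tokens': 20, 'batch_size': 2},
]
# A 100-token prompt only fits the 128-token shape.
print(applicable_shapes(shapes, num_new_tokens=100, max_tokens=20,
                        num_scheduled=0))
```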
12 changes: 9 additions & 3 deletions vllm/engine/llm_engine.py
```diff
@@ -17,8 +17,7 @@
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ObservabilityConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
-from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
-                                 SchedulerOutputs)
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase, Stats
 from vllm.engine.output_processor.interfaces import (
@@ -56,7 +55,8 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import (Counter, Device, deprecate_kwargs,
+                        resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -344,6 +344,12 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
+
+        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
+            Scheduler = resolve_obj_by_qualname(
+                self.vllm_config.scheduler_config.scheduler_cls)
+        else:
+            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
         self.scheduler = [
             Scheduler(
                 self.scheduler_config, self.cache_config, self.lora_config,
```
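
The engine resolves a string-valued `scheduler_cls` with `vllm.utils.resolve_obj_by_qualname`, which boils down to an import plus an attribute lookup. A rough equivalent is sketched below; the helper is our approximation, not vLLM's exact implementation.

```python
import importlib
from typing import Any

def resolve_by_qualname(qualname: str) -> Any:
    """Resolve 'package.module.Name' by importing the module part and
    returning the named attribute."""
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)

# The default value of scheduler_cls resolves back to vLLM's own scheduler.
Scheduler = resolve_by_qualname("vllm.core.scheduler.Scheduler")
```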
27 changes: 0 additions & 27 deletions vllm/envs.py
```diff
@@ -67,9 +67,6 @@
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
-    VLLM_SPYRE_WARMUP_PROMPT_LENS: Optional[List[int]] = None
-    VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None
-    VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
@@ -470,30 +467,6 @@ def get_default_config_root():
     lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
     "VLLM_DISABLE_COMPILE_CACHE":
     lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
-
-    # Defines the prompt lengths the Spyre accelerator should be prepared
-    # for, formatted as comma separated list.
-    "VLLM_SPYRE_WARMUP_PROMPT_LENS":
-    lambda: [
-        int(p) for p in os.getenv(key='VLLM_SPYRE_WARMUP_PROMPT_LENS',
-                                  default='64').split(',')
-    ],
-
-    # Defines the max output tokens the Spyre accelerator should be prepared
-    # for, formatted as comma separated list.
-    "VLLM_SPYRE_WARMUP_NEW_TOKENS":
-    lambda: [
-        int(d) for d in os.getenv(key='VLLM_SPYRE_WARMUP_NEW_TOKENS',
-                                  default='20').split(',')
-    ],
-
-    # Defines the batch sizes the Spyre accelerator should be prepared
-    # for, formatted as comma separated list.
-    "VLLM_SPYRE_WARMUP_BATCH_SIZES":
-    lambda: [
-        int(b) for b in os.getenv(key='VLLM_SPYRE_WARMUP_BATCH_SIZES',
-                                  default='1').split(',')
-    ],
 }
 
 # end-env-vars-definition
```
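
Each deleted entry followed the same pattern: parse a comma-separated list of integers from an environment variable, falling back to a single-value default. A standalone sketch of that pattern, reusing the defaults from the removed definitions (presumably close to what `vllm-spyre` now does on its side):

```python
import os
from typing import List

def parse_int_list(name: str, default: str) -> List[int]:
    """Parse an env var such as '64,128,256' into [64, 128, 256]."""
    return [int(v) for v in os.getenv(name, default).split(',')]

# Defaults taken from the removed definitions.
prompt_lens = parse_int_list('VLLM_SPYRE_WARMUP_PROMPT_LENS', '64')
new_tokens = parse_int_list('VLLM_SPYRE_WARMUP_NEW_TOKENS', '20')
batch_sizes = parse_int_list('VLLM_SPYRE_WARMUP_BATCH_SIZES', '1')
```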
