diff --git a/vllm/config.py b/vllm/config.py
index e7fb8e2d0..150192599 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3,7 +3,6 @@
 import enum
 import hashlib
 import json
-import operator
 import sys
 import warnings
 from contextlib import contextmanager
@@ -1428,6 +1427,9 @@ class SchedulerConfig:
 
     chunked_prefill_enabled: bool = field(init=False)
 
+    # scheduler class or path
+    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -1486,41 +1488,6 @@ def __post_init__(self) -> None:
                 self.max_num_batched_tokens)
 
         self.chunked_prefill_enabled = self.enable_chunked_prefill
-        from vllm.platforms import current_platform
-        self.spyre_scheduling_enabled = current_platform.get_device_name(
-        ) == "spyre"
-        if self.spyre_scheduling_enabled:
-            # load warmup shapes and sort by "speed"
-            wup_prompt_lens = envs.VLLM_SPYRE_WARMUP_PROMPT_LENS or []
-            wup_batch_sizes = envs.VLLM_SPYRE_WARMUP_BATCH_SIZES or []
-            if len(wup_prompt_lens) != len(wup_batch_sizes):
-                raise RuntimeError(
-                    "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
-                    "VLLM_SPYRE_WARMUP_BATCH_SIZES must have equal length")
-            if self.runner_type == "pooling":
-                wup_new_tokens = [0] * len(wup_prompt_lens)
-            else:
-                wup_new_tokens = envs.VLLM_SPYRE_WARMUP_NEW_TOKENS or []
-                if len(wup_new_tokens) != len(wup_prompt_lens):
-                    raise RuntimeError(
-                        "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
-                        "VLLM_SPYRE_WARMUP_NEW_TOKENS must have equal length")
-
-            print("[SchedulerConfig] VLLM_SPYRE_WARMUP_PROMPT_LENS =",
-                  wup_prompt_lens)
-            print("[SchedulerConfig] VLLM_SPYRE_WARMUP_NEW_TOKENS =",
-                  wup_new_tokens)
-            print("[SchedulerConfig] VLLM_SPYRE_WARMUP_BATCH_SIZES =",
-                  wup_batch_sizes)
-
-            self.spyre_warmup_shapes = tuple(
-                sorted([{
-                    'prompt_length': pl,
-                    'new_tokens': nt,
-                    'batch_size': bs
-                } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
-                                        wup_batch_sizes)],
-                       key=operator.itemgetter('batch_size', 'prompt_length')))
 
         self._verify_args()
 
     def _verify_args(self) -> None:
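Note on the new field: scheduler_cls accepts either a dotted import path or an actual class object. The sketch below is illustrative only (the helper name pick_scheduler is not part of vLLM or of this diff); it simply shows the two accepted spellings on an existing SchedulerConfig instance.

    # Illustrative sketch, not part of this diff or of vLLM's public API.
    from vllm.config import SchedulerConfig
    from vllm.core.scheduler import Scheduler


    def pick_scheduler(scheduler_config: SchedulerConfig,
                       use_string: bool) -> None:
        """Show the two accepted forms of scheduler_cls (hypothetical helper)."""
        if use_string:
            # Dotted import path; the engine resolves it at construction time.
            scheduler_config.scheduler_cls = "vllm.core.scheduler.Scheduler"
        else:
            # Or a class object, e.g. the stock Scheduler or a subclass of it.
            scheduler_config.scheduler_cls = Scheduler
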
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 0834706c7..b3d396f9c 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -919,10 +919,6 @@ def _schedule_prefills(
         ignored_seq_groups: List[SequenceGroup] = []
         seq_groups: List[ScheduledSequenceGroup] = []
 
-        if self.scheduler_config.spyre_scheduling_enabled:
-            applicable_spyre_warmup_shapes = list(
-                self.scheduler_config.spyre_warmup_shapes)
-
         waiting_queue = self.waiting
 
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
@@ -1003,54 +999,6 @@
             ):
                 break
 
-            # check if current request can be scheduled based on the applicable
-            # spyre warmup shapes
-            if self.scheduler_config.spyre_scheduling_enabled:
-                max_tokens = 0
-                if seq_group.sampling_params is not None and\
-                    seq_group.sampling_params.max_tokens is not None:
-                    max_tokens = seq_group.sampling_params.max_tokens
-                updated_spyre_warmup_shapes = [
-                    shape for shape in applicable_spyre_warmup_shapes
-                    if num_new_tokens <= shape['prompt_length']
-                    and max_tokens <= shape['new_tokens']
-                    and len(seq_groups) < shape['batch_size']
-                ]
-                if not updated_spyre_warmup_shapes:
-                    if not seq_groups:
-                        # request was tested against all spyre warmup shapes:
-                        # request cannot be processed
-                        if (seq_group.sampling_params is not None
-                                and seq_group.sampling_params.max_tokens
-                                is not None):
-                            logger.warning(
-                                "No applicable warmup shape exists for "
-                                "combination of prompt length (%d tokens) "
-                                "and maximum number of output tokens to be "
-                                "generated (%d tokens)", num_new_tokens,
-                                seq_group.sampling_params.max_tokens)
-                        else:
-                            logger.warning(
-                                "No applicable warmup shape exists for "
-                                "combination of prompt length (%d tokens) "
-                                "and undefined maximum number of output "
-                                "tokens", num_new_tokens)
-                        for seq in waiting_seqs:
-                            seq.status = SequenceStatus.FINISHED_IGNORED
-                        ignored_seq_groups.append(seq_group)
-                        waiting_queue.popleft()
-                        continue
-                    else:
-                        # request was only tested against spyre warmup shapes
-                        # that remain after processing previous requests in
-                        # waiting queue: request will be evaluated again in
-                        # a future scheduling step
-                        leftover_waiting_sequences.appendleft(seq_group)
-                        waiting_queue.popleft()
-                        continue
-                else:
-                    applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes
-
             # Can schedule this request.
             if curr_loras is not None and lora_int_id > 0:
                 curr_loras.add(lora_int_id)
@@ -1084,15 +1032,6 @@
             )
             budget.add_num_seqs(seq_group.request_id, num_new_seqs)
 
-            # Check if number of scheduled requests has reached the maximum
-            # batch size of the applicable warmup shapes
-            if self.scheduler_config.spyre_scheduling_enabled and len(
-                    seq_groups) >= max([
-                        shape['batch_size']
-                        for shape in applicable_spyre_warmup_shapes
-                    ]):
-                break
-
         # Queue requests that couldn't be scheduled.
         waiting_queue.extendleft(leftover_waiting_sequences)
         if len(seq_groups) > 0:
@@ -1130,11 +1069,8 @@ def _schedule_default(self) -> SchedulerOutputs:
         running_scheduled = SchedulerRunningOutputs.create_empty()
         swapped_in = SchedulerSwappedInOutputs.create_empty()
 
-        # Schedule new prefills only when no requests have been swapped
-        # and all previous decodes have completed.
-        if not self.swapped and (
-                not self.scheduler_config.spyre_scheduling_enabled
-                or not self.running):
+        # If any requests are swapped, prioritize swapped requests.
+        if not self.swapped:
             prefills = self._schedule_prefills(budget,
                                                curr_loras,
                                                enable_chunking=False)
@@ -1335,13 +1271,8 @@ def _can_append_slots(self, seq_group: SequenceGroup,
             # chunked-prefill are enabled together.
             assert self.scheduler_config.is_multi_step and enable_chunking
 
-        if self.scheduler_config.spyre_scheduling_enabled:
-            # heuristic below doesn't make sense when using very large
-            # blocks
-            return True
-        else:
-            return self.block_manager.can_append_slots(
-                seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
+        return self.block_manager.can_append_slots(
+            seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
 
     def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         # async_output_proc is allowed only when we have a single sequence
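With the Spyre-specific branches gone, scheduler customization is expected to go through scheduler_cls instead. Below is a minimal sketch of an out-of-tree scheduler, assuming only what this diff shows (Scheduler and SchedulerOutputs live in vllm.core.scheduler, and _schedule_default returns a SchedulerOutputs); the scheduled_seq_groups attribute is an assumption to verify against your vLLM version.

    # Hedged sketch of a pluggable scheduler subclass.
    from vllm.core.scheduler import Scheduler, SchedulerOutputs


    class LoggingScheduler(Scheduler):
        """Hypothetical subclass that reports each default scheduling step."""

        def _schedule_default(self) -> SchedulerOutputs:
            outputs = super()._schedule_default()
            # scheduled_seq_groups is assumed; check the SchedulerOutputs
            # fields in your vLLM version before relying on it.
            print(f"scheduled {len(outputs.scheduled_seq_groups)} seq groups")
            return outputs

Pointing scheduler_cls at the dotted path of such a class (or at the class itself) is then enough for the engine change below to pick it up.
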
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 88c21f9a6..695d6fa0a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -17,8 +17,7 @@
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ObservabilityConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
-from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
-                                 SchedulerOutputs)
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase, Stats
 from vllm.engine.output_processor.interfaces import (
@@ -56,7 +55,8 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import (Counter, Device, deprecate_kwargs,
+                        resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -344,6 +344,12 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
+
+        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
+            Scheduler = resolve_obj_by_qualname(
+                self.vllm_config.scheduler_config.scheduler_cls)
+        else:
+            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
         self.scheduler = [
             Scheduler(
                 self.scheduler_config, self.cache_config, self.lora_config,
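resolve_obj_by_qualname already exists in vllm.utils; the diff only starts importing it here. Its job is roughly "dotted path in, object out", along the lines of this illustrative stand-in (not the actual vLLM implementation).

    # Illustration only; use vllm.utils.resolve_obj_by_qualname in real code.
    import importlib
    from typing import Any


    def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
        """Resolve 'package.module.AttrName' to the named attribute."""
        module_name, _, attr_name = qualname.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, attr_name)


    # For example, the default value of scheduler_cls:
    # resolve_obj_by_qualname_sketch("vllm.core.scheduler.Scheduler")
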
- "VLLM_SPYRE_WARMUP_NEW_TOKENS": - lambda: [ - int(d) for d in os.getenv(key='VLLM_SPYRE_WARMUP_NEW_TOKENS', - default='20').split(',') - ], - - # Defines the batch sizes the Spyre accelerator should be prepared - # for, formatted as comma separated list. - "VLLM_SPYRE_WARMUP_BATCH_SIZES": - lambda: [ - int(b) for b in os.getenv(key='VLLM_SPYRE_WARMUP_BATCH_SIZES', - default='1').split(',') - ], } # end-env-vars-definition