Add batched prefill via VLLM_SCHED_PREFILL_COUNT
To ensure we don't repeatedly run prefills during decode, provide a
mechanism to queue up a certain number of prefills before executing.
VLLM_SCHED_PREFILL_COUNT is the minimum number of queued prefills
required before execution. One caveat: --scheduler-delay-factor should
be used to enforce a longer prefill scheduling delay; if not explicitly
provided, it is set to the value of VLLM_SCHED_PREFILL_COUNT. The delay
is needed because an uneven number of prefills can leave the queue
permanently short of VLLM_SCHED_PREFILL_COUNT, which would otherwise
hang the server.
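For illustration, a minimal sketch of the resulting startup behavior;
the bare delay_factor variable is a stand-in for the config field, not
part of this commit:

    import os

    # Mirrors the commit's env parsing; 0 (the default) disables batching.
    VLLM_SCHED_PREFILL_COUNT = int(os.getenv("VLLM_SCHED_PREFILL_COUNT", 0))

    # Stand-in for the value --scheduler-delay-factor would set.
    delay_factor = 0.0

    if VLLM_SCHED_PREFILL_COUNT > 0 and delay_factor == 0:
        # Without a time-based fallback the count check alone could wait
        # forever on a short queue, so default to the prefill count.
        delay_factor = VLLM_SCHED_PREFILL_COUNT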
dllehr-amd authored and valarLip committed Aug 11, 2024
1 parent ce8a86e commit 8148b54
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions vllm/core/scheduler.py
@@ -23,6 +23,9 @@
 ARTIFICIAL_PREEMPTION_PROB = 0.5
 ARTIFICIAL_PREEMPTION_MAX_CNT = 500
 
+VLLM_SCHED_PREFILL_COUNT = int(
+    os.getenv("VLLM_SCHED_PREFILL_COUNT", 0))  # noqa
+
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -263,7 +266,15 @@ def __init__(
         # simple and NOT fair. It can lead to starvation of some
         # LoRAs. This should be improved in the future.
         self.lora_config = lora_config
-
+        self.prefill_timeout = 0
+
+        # Slightly hacky: if a prefill batch count is specified, the delay
+        # factor must be nonzero, otherwise we will always skip. Default it
+        # to VLLM_SCHED_PREFILL_COUNT, as the two should be roughly the
+        # same; tune it explicitly with --scheduler-delay-factor on the
+        # command line.
+        if VLLM_SCHED_PREFILL_COUNT > 0 and self.scheduler_config.delay_factor == 0:
+            self.scheduler_config.delay_factor = VLLM_SCHED_PREFILL_COUNT
         version = "v1"
         if self.scheduler_config.use_v2_block_manager:
             version = "v2"
@@ -644,7 +655,8 @@ def _schedule_prefills(
         waiting_queue = deque([s for s in waiting_queue])
 
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
-        while self._passed_delay(time.time()) and waiting_queue:
+
+        while (VLLM_SCHED_PREFILL_COUNT <= len(waiting_queue) or self._passed_delay(time.time())) and waiting_queue:
             seq_group = waiting_queue[0]
 
             waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
@@ -719,7 +731,6 @@ def _schedule_prefills(
         waiting_queue.extendleft(leftover_waiting_sequences)
         if len(seq_groups) > 0:
             self.prev_prompt = True
-
         return waiting_queue, SchedulerPrefillOutputs(
             seq_groups=seq_groups,
             ignored_seq_groups=ignored_seq_groups,
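For a rough, self-contained illustration of the new gate in
_schedule_prefills (hypothetical helper names; not vLLM's actual code):
prefills drain once the waiting queue reaches VLLM_SCHED_PREFILL_COUNT,
with the delay check as the time-based fallback:

    import time
    from collections import deque

    VLLM_SCHED_PREFILL_COUNT = 4
    DELAY_S = 2.0  # stand-in for the --scheduler-delay-factor timeout

    def drain_prefills(waiting: deque, started: float) -> list:
        """Pop prefills while the batch threshold or delay fallback allows."""
        scheduled = []
        while waiting and (len(waiting) >= VLLM_SCHED_PREFILL_COUNT
                           or time.time() - started > DELAY_S):
            scheduled.append(waiting.popleft())
        return scheduled

    # Five queued requests with the delay not yet passed: only two are
    # scheduled before the queue drops below the threshold; the leftover
    # three drain once DELAY_S elapses, which is why a nonzero delay
    # factor is required to avoid a hang.
    print(drain_prefills(deque(["a", "b", "c", "d", "e"]), time.time()))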