
[Bugfix] Multi-sequence broken #11898

Open · andylolu2 wants to merge 2 commits into main

Conversation

andylolu2

@andylolu2 andylolu2 commented Jan 9, 2025

Fixes the bugs introduced in #9569:

  • A SequenceGroup does not necessarily contain only one sequence (e.g. when n > 1), so many of the single-sequence optimisations are invalid.
  • The seed is currently duplicated across all completions, so with n > 1 and a fixed seed, every completion produces the same output.
  • Currently only the first sequence in a ParallelSampleSequenceGroup yields responses, and once that first sequence finishes it no longer receives new chunks. As a result, when the first sequence terminates first, responses from the other sequences are never sent.
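The seed-duplication bug in the second bullet can be demonstrated in isolation: if every completion reuses the same seed, independently seeded generators produce identical samples. A minimal sketch with the standard `random` module (a stand-in for vLLM's sampler, not its actual code):

```python
import random

def sample_completion(seed: int, length: int = 5) -> list[int]:
    # Stand-in for one sampled completion: a short stream of token ids.
    rng = random.Random(seed)
    return [rng.randrange(1000) for _ in range(length)]

# Bug: the same seed is copied to every completion, so all n outputs match.
same_seed = [sample_completion(seed=42) for _ in range(3)]
assert all(c == same_seed[0] for c in same_seed)

# Fix: derive a distinct seed per completion (seed + i), as this PR does.
distinct = [sample_completion(seed=42 + i) for i in range(3)]
assert len({tuple(c) for c in distinct}) > 1
```

Deriving `seed + i` per completion restores independent samples while keeping the run reproducible.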


github-actions bot commented Jan 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@andylolu2
Author

@youkaichao

vllm/sequence.py Outdated
Comment on lines 821 to 831
n = self.sampling_params.n
assert isinstance(n, int)
if n > self.num_seqs():
    # At prompt stage, the sequence group is not yet filled up
    # and only have one sequence running. However, in the
    # generation stage, we will have `n` sequences
    # running.
    return n
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_seqs() - self.num_finished_seqs()
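The branch logic above can be exercised with plain integers (hypothetical num_seqs/num_finished values, not vLLM's real SequenceGroup):

```python
def remaining_seqs(n: int, num_seqs: int, num_finished: int) -> int:
    # Prompt stage: only one sequence exists so far, but `n` will be scheduled.
    if n > num_seqs:
        return n
    # Generation stage: count the sequences that have not finished yet.
    return num_seqs - num_finished

assert remaining_seqs(n=4, num_seqs=1, num_finished=0) == 4  # prompt stage
assert remaining_seqs(n=4, num_seqs=4, num_finished=1) == 3  # generation stage
```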
Member


when will we hit this? I think the engine will only see single-sequence request

Author


When you construct the output with n > 1, you access the "master group".

Author

@andylolu2 andylolu2 Jan 10, 2025


For example, you construct the RequestOutput with multiple sequences here:

return cls.from_seq_group(assembled_seq_group, use_cache,

Then call master_seq_group.is_finished() here:

finished = seq_group.is_finished()

This currently already becomes True when the first sequence terminates, regardless of whether the other sequences have terminated.
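The effect can be sketched with simplified stand-ins (hypothetical Seq/Group classes, not vLLM's actual types):

```python
from dataclasses import dataclass, field

@dataclass
class Seq:
    finished: bool = False

@dataclass
class Group:
    seqs: list[Seq] = field(default_factory=list)

    def is_finished_buggy(self) -> bool:
        # Buggy behaviour: only the first sequence is consulted, so the
        # request is reported finished as soon as that sequence stops.
        return self.seqs[0].finished

    def is_finished_fixed(self) -> bool:
        # Correct behaviour: the group finishes only when all sequences do.
        return all(s.finished for s in self.seqs)

g = Group([Seq(finished=True), Seq(finished=False)])
assert g.is_finished_buggy() is True    # request closed too early
assert g.is_finished_fixed() is False   # other sequences still streaming
```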

Comment on lines +1422 to +1425
params = copy.deepcopy(original_params)
params.n = 1
if params.seed is not None:
    params.seed += i
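The diff above can be run end-to-end with a minimal stand-in for SamplingParams (the real vllm.SamplingParams has many more fields; only n and seed matter here):

```python
import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParams:  # minimal stand-in for vllm.SamplingParams
    n: int = 1
    seed: Optional[int] = None

original_params = SamplingParams(n=3, seed=7)
per_seq = []
for i in range(original_params.n):
    # One single-sequence copy per parallel sample, with a derived seed.
    params = copy.deepcopy(original_params)
    params.n = 1
    if params.seed is not None:
        params.seed += i
    per_seq.append(params)

assert [p.seed for p in per_seq] == [7, 8, 9]
assert all(p.n == 1 for p in per_seq)
```

Each child sequence thus samples with its own seed while an unseeded request (`seed=None`) is left untouched.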
Member


this part makes sense to me.

@youkaichao
Member

@andylolu2 thanks for the fix! can you add a test case for n > 1 and seed to make sure they are different?

@andylolu2 andylolu2 force-pushed the main branch 2 times, most recently from 0be80f4 to 7c31b9c Compare January 12, 2025 21:58
@andylolu2
Author

andylolu2 commented Jan 12, 2025

@youkaichao I added new asserts in the current tests to ensure each sample in the same parallel-sampling group gives different results.
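The added assertion boils down to a distinctness check across the parallel samples; a hedged sketch (hypothetical helper and outputs, not the real test file):

```python
def assert_parallel_samples_differ(texts: list[str]) -> None:
    # With a fixed seed and n > 1, each completion should use a distinct
    # derived seed, so the sampled texts should not all be identical.
    assert len(set(texts)) > 1, "all parallel samples are identical"

# Example: two distinct texts among three samples passes the check.
assert_parallel_samples_differ(["the cat sat", "a dog ran", "the cat sat"])
```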
