Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Optimize handling of sampling metadata and req_ids list #13244

Merged
merged 17 commits into from
Feb 18, 2025

Conversation

njhill
Copy link
Member

@njhill njhill commented Feb 13, 2025

  • Move the current SamplingMetadata object to a field in the persistent batch, updated only when the batch changes rather than constructed every step
  • Keep input_batch.req_ids sized to the number of requests in the batch, so that anywhere that iterates over it doesn't need to slice (copy) the list or keep track of the separate request count. It is still updated in-place

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Feb 13, 2025
Copy link

mergify bot commented Feb 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 14, 2025
- Move SamplingMetadata to a field in the persistent batch, updated only when the batch changes rather than constructed every step
- Keep input_batch.req_ids sized to the number of requests in the batch, so that anywhere that iterates over it doesn't need to slice (copy) the list or keep track of the separate request count. It is still updated in-place

Signed-off-by: Nick Hill <[email protected]>
@njhill njhill force-pushed the sampler-streamline branch from 2bcf20f to 7d6ee8f Compare February 14, 2025 16:27
@mergify mergify bot removed the needs-rebase label Feb 14, 2025
@njhill
Copy link
Member Author

njhill commented Feb 14, 2025

@WoosukKwon this is the first step, I am working on follow-on simplification for the penalty parameters, etc.

@WoosukKwon WoosukKwon self-assigned this Feb 14, 2025
@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 14, 2025
@njhill
Copy link
Member Author

njhill commented Feb 14, 2025

@WoosukKwon apologies, I am looking into the test failure.

@njhill
Copy link
Member Author

njhill commented Feb 14, 2025

@WoosukKwon the test failure should be fixed now... the shared apply penalties code was doing in-place unsqueezes on the sampling penalty tensors - which I think is a bad thing to do but didn't cause a problem before because we were passing new slices every step.

Copy link

mergify bot commented Feb 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 14, 2025
# Conflicts:
#	vllm/v1/worker/gpu_input_batch.py
@mergify mergify bot removed the needs-rebase label Feb 15, 2025
@WoosukKwon
Copy link
Collaborator

Hi @njhill, do you mind if we merge #12193 first and review this PR? I'd like to prioritize the spec decode PR as it already got rebased many many times.

@njhill
Copy link
Member Author

njhill commented Feb 15, 2025

@WoosukKwon that's fine with me.

…streamline

Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	tests/v1/worker/test_gpu_input_batch.py
#	vllm/v1/sample/sampler.py
Signed-off-by: Nick Hill <[email protected]>
Copy link

mergify bot commented Feb 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@WoosukKwon
Copy link
Collaborator

@njhill Sorry for the delay. I will review this PR once it's rebased.

Signed-off-by: Nick Hill <[email protected]>

# Conflicts:
#	tests/v1/sample/test_sampler.py
#	tests/v1/worker/test_gpu_input_batch.py
#	vllm/v1/worker/gpu_input_batch.py
#	vllm/v1/worker/gpu_model_runner.py
@mergify mergify bot removed the needs-rebase label Feb 18, 2025
vllm/v1/request.py Outdated Show resolved Hide resolved
vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
@njhill
Copy link
Member Author

njhill commented Feb 18, 2025

@WoosukKwon I have now rebased. #13360 partially overlaps with this (e,g. I simplified some of the min_tokens handling in this one but have refactored completely in the other one based on the new abstraction). But I think it would be fine to get this in first and I can rebase the other one if you're ok with that.

Signed-off-by: Nick Hill <[email protected]>
@WoosukKwon
Copy link
Collaborator

@njhill I'm not sure it's worthwhile to change from [] to ().
I did a microbenchmark:

N = 1024
x = []
# List
start = time.perf_counter()
for i in range(N):
    x.append([])
end = time.perf_counter()
print(f"list: {(end - start) * 1000:.3f} ms")

y = []
# Tuple
start = time.perf_counter()
for i in range(N):
    y.append(())
end = time.perf_counter()
print(f"tuple: {(end - start) * 1000:.3f} ms")

I find that adding 1024 (maximum number of requests in the batch) empty lists only takes 80-90 us. While using tuple reduces this time to 30-40 us, I think the 50 us gap (in the worst case) cannot justify the extra complexity here. When the batch size is 32, the gap becomes even smaller (7 us vs 2 us). WDYT?

@njhill
Copy link
Member Author

njhill commented Feb 18, 2025

@WoosukKwon I agree it's not worth any extra complexity. Just might as well use () where it doesn't otherwise make any difference to the code. Let me check and revert where such changes were made..

@WoosukKwon
Copy link
Collaborator

@njhill I think changing List to Sequence itself is increasing complexity? After that, we need to consider whether it's a tuple or list. I'd prefer to keep using List and [] if the performance is the only concern.

@njhill
Copy link
Member Author

njhill commented Feb 18, 2025

@WoosukKwon sure, let me revert those too. I think mostly we don't need to consider the tuple/list difference because these are args or fields that would be considered read-only.

@njhill
Copy link
Member Author

njhill commented Feb 18, 2025

@WoosukKwon I need to fix up some of the gpu_model_runner tests, but I'll wait for your first review to make sure you are good with the changes overall before spending time on that.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing. Looks much cleaner! 😄

vllm/v1/worker/gpu_model_runner.py Show resolved Hide resolved
Comment on lines +198 to +200
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])
request.spec_token_ids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this change for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It avoids creating a new list, just trims the existing one down to num_scheduled_spec_tokens, since any later spec token ids are essentially discarded anyhow.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it! Maybe worth a comment.

vllm/v1/sample/metadata.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_input_batch.py Show resolved Hide resolved
vllm/v1/worker/gpu_input_batch.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_input_batch.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_input_batch.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Very nice simplification!

Signed-off-by: Nick Hill <[email protected]>
@njhill njhill merged commit 30172b4 into vllm-project:main Feb 18, 2025
44 checks passed
@njhill njhill deleted the sampler-streamline branch February 18, 2025 20:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants