Parallel sampling support #106
Conversation
This reverts commit 5382004.
```python
def get_prompt_sequence_id(request_id: RequestId) -> SequenceId:
    return SequenceId(request_id, PROMPT_SEQEUNCE_INDEX)
```
Should we also update the assert at https://github.com/masahi/mlc-llm/blob/parallel-sampling-dev/serve/mlc_serve/engine/base.py#L148 accordingly?
I don't think we support that. Besides, is there a point in allowing greedy sampling when `n > 1`?
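For readers following along, here is a minimal sketch of how a reserved prompt-sequence index can let the `n` generated sequences share a single prompt entry. The `RequestId` alias, the concrete value of `PROMPT_SEQEUNCE_INDEX`, and the helper for decode sequence ids are assumptions for illustration, not the PR's actual definitions:

```python
from typing import List, NamedTuple

RequestId = str  # assumed alias; the engine may use a different type

class SequenceId(NamedTuple):
    request_id: RequestId
    sequence_index: int

# Hypothetical reserved index marking the shared prompt sequence.
PROMPT_SEQEUNCE_INDEX = -1

def get_prompt_sequence_id(request_id: RequestId) -> SequenceId:
    return SequenceId(request_id, PROMPT_SEQEUNCE_INDEX)

def get_decode_sequence_ids(request_id: RequestId, num_sequences: int) -> List[SequenceId]:
    # Each of the n parallel samples gets its own decode sequence id,
    # while all of them map back to the single shared prompt sequence above.
    return [SequenceId(request_id, i) for i in range(num_sequences)]
```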
```diff
     )

     if args.use_staging_engine:
         engine.stop()

     total_num_tokens = sum(
-        prompt_len + output_len for _, prompt_len, output_len in requests
+        prompt_len + output_len * args.num_sequences_to_sample for _, prompt_len, output_len in requests
```
Not a request to change anything, just some reflections: the metric we calculate, tokens per second, might not reflect the real number of processed tokens, because with cache eviction we can process the same request several times. That is, we may process more tokens than this `total_num_tokens` accounts for. Not sure whether it affects any conclusions or the overall picture.
imo, we should not include re-computed tokens in this kind of benchmarking, to avoid a situation where throughput looks better than it should because of re-computation. If many tokens have to be re-computed to generate a valid output token, that is a problem of the engine, and the benchmark script should reflect it.
I agree with @sunggg here
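To make the agreed counting rule concrete, here is a small sketch of the throughput formula under the convention discussed above: count each prompt once, count the output tokens of every one of the `n` samples, and do not count re-computed tokens. The function names are illustrative, not the benchmark script's own:

```python
def compute_total_num_tokens(requests, num_sequences_to_sample: int) -> int:
    # requests: list of (prompt, prompt_len, output_len) tuples, as in the benchmark script.
    # Each request contributes its prompt once, plus output_len tokens per sample.
    return sum(
        prompt_len + output_len * num_sequences_to_sample
        for _, prompt_len, output_len in requests
    )

def throughput_tokens_per_sec(requests, num_sequences_to_sample: int, elapsed_s: float) -> float:
    # Tokens re-computed after cache eviction are deliberately *not* counted, so an
    # eviction-heavy run shows up as lower measured throughput rather than inflating it.
    return compute_total_num_tokens(requests, num_sequences_to_sample) / elapsed_s

# e.g. one request with a 100-token prompt, 50 output tokens, n = 4 samples:
assert compute_total_num_tokens([("p", 100, 50)], 4) == 300
```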
```python
            sampling_params=state.sampling_params,
        )
    )
    cache_manager.extend(
        gen_seq.seq_id,
        len(token_ids) - gen_seq.next_start_position,
        prompt_counts
```
Comparing the previous calculation of new tokens with the new implementation, I don't understand why we need such a complex formula here. Why do we derive the number of new tokens from the previous tokens? I would understand if this were simply 1, but I don't see why we go through the total token count and the next position.
Can you point to decode cases where this is not 1?
Previously, `len(token_ids)` was the combined prompt and decode token count. This diff doesn't change anything.
I think we have future support for speculative decoding in mind here; there we may generate multiple tokens in one decoding step.
> This diff doesn't change anything

I absolutely agree with this.

> I think we have future support for speculative decoding in mind here

I can hardly imagine decoding several tokens per step without a significant modification of the logic. Again, I'm not proposing a change right now, but as it stands this could just be `extend(gen_seq.seq_id, 1)`.
Yes, it's going to be the next big project. It will be very cool.
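A tiny sketch of the counting logic under discussion, with hypothetical names mirroring the snippet above (`token_ids` holds prompt plus generated tokens, `next_start_position` is the index of the first token not yet covered by the cache):

```python
def num_new_tokens(token_ids: list, next_start_position: int) -> int:
    # In an ordinary decode step exactly one token was appended, so this evaluates to 1.
    # Deriving the count from positions (rather than hard-coding 1) leaves room for
    # speculative decoding, where several tokens may be accepted in a single step.
    return len(token_ids) - next_start_position

# Ordinary decode: 10 prompt tokens + 3 generated so far, cache already covers the first 12.
assert num_new_tokens(list(range(13)), 12) == 1

# Hypothetical speculative step accepting 3 tokens at once.
assert num_new_tokens(list(range(16)), 13) == 3
```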
```python
    self.current_batch[request_to_remove.request_id].num_sequences == 1
), "Evicting a multi-sequence request is not supported."

# TODO(masahi): Properly support evicting a multi-sequence request
```
A disappointing solution. Let's at least change the algorithm for choosing the request to evict: select from the requests with `n == 1`, and only if no such request can be found, evict a request with `n > 1`.
I think this is a temporary solution until we land eviction support. I'm okay with it as a temporary measure if we can follow up quickly.
I'll follow up on this today. This is indeed a temporary solution, but even after we have a proper one, it's still good to make a best effort to avoid evicting a parallel-sampling request.
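A sketch of the candidate-selection heuristic suggested in this thread: prefer evicting single-sequence requests and fall back to a parallel-sampling request only when nothing else is left. The `RequestState` shape and the recency ordering are assumptions for illustration, not the engine's data structures:

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class RequestState:
    num_sequences: int  # n == 1 for ordinary requests, n > 1 for parallel sampling

def pick_request_to_evict(current_batch: Dict[str, RequestState]) -> Optional[str]:
    # Prefer a request with a single sequence, since evicting it only requires
    # re-computing one sequence's KV cache later.
    for request_id, state in reversed(list(current_batch.items())):
        if state.num_sequences == 1:
            return request_id
    # Only if every request in the batch is a parallel-sampling request do we
    # (reluctantly) pick one of those.
    return next(reversed(list(current_batch)), None)
```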
```python
):
    # This sequence is trying to overwrite a prompt block shared with other sequences.

    # TODO(masahi): The engine should take into account this additional
```
The comment is not clear. The verification and handling seem correct. What additionally should the engine take into account, and in which case?
Thinking about it more, we probably don't need to worry about this since the engine conservatively assumes that each decode step can allocate one block for all sequences (correct me if I'm wrong about this). Whenever we hit this code path, we don't do allocation at L226. So we only allocate up to one block for all sequences per decode step in all code paths.
Won't we run into this assert if the request has …?
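For intuition, here is a simplified check for the situation flagged in this code path: the last prompt block is only partially filled when the prompt length is not divisible by the block size, so a decode sequence about to write into it must not do so in place while the block is still shared by all `n` samples. The names and the exact condition are illustrative, not the PR's implementation:

```python
def writes_into_shared_prompt_block(prompt_len: int, block_size: int, write_pos: int) -> bool:
    # The last prompt block is partially filled iff prompt_len % block_size != 0.
    # A decode token written at write_pos lands in that block when both positions
    # share the same block index; that block is shared by all samples, so the
    # sequence needs a private copy (or separate handling) before writing.
    if prompt_len % block_size == 0:
        return False
    last_prompt_block = (prompt_len - 1) // block_size
    return write_pos // block_size == last_prompt_block

# Example: block size 16, 20-token prompt -> block 1 holds tokens 16..19 and is only
# partially filled; the first decode token (position 20) would land in that block.
assert writes_into_shared_prompt_block(20, 16, 20) is True
assert writes_into_shared_prompt_block(32, 16, 32) is False
```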
Thank you @masahi for another great piece of work! Overall, LGTM. I'm also okay with the temporary solutions you mentioned for the sake of fast iteration, so I'd like us to merge this quickly and follow up quickly.
```python
    num_tokens % self.block_size != 0
)

if self.sliding_window:
```
Orthogonal to this PR, but what happens if we have more prompt tokens than the sliding window size?
That's a very common case with Mistral and long prompts. The code paths in
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/model/paged_cache_model.py#L65-L73
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/model/paged_cache_model.py#L110-L114
are specifically for that case. Prompt blocks already wrap around before any decoding happens due to circular buffering, so at each decoding step we need to carefully determine which portion of the prompt blocks is still shared among the `n` samples.
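A rough sketch of that bookkeeping, under a simplified model in which the KV cache is a circular buffer of `sliding_window` token slots split into fixed-size blocks. The model and all names are assumptions for illustration, not the logic in paged_cache_model.py:

```python
def shared_prompt_blocks_remaining(prompt_len: int, tokens_generated: int,
                                   sliding_window: int, block_size: int) -> int:
    # At most `sliding_window` prompt tokens ever fit in the circular buffer.
    cached_prompt_tokens = min(prompt_len, sliding_window)
    # Once the window is full, each newly generated token overwrites the slot of the
    # oldest cached token, eating into the prompt region shared by the n samples.
    overwritten = max(0, cached_prompt_tokens + tokens_generated - sliding_window)
    still_shared = max(0, cached_prompt_tokens - overwritten)
    # Only blocks that are still fully intact remain safely shared.
    return still_shared // block_size

# 8k prompt, 4k window, 16-token blocks: the window is already full before decoding
# starts, and every decode step shrinks the shared prompt region by one token.
assert shared_prompt_blocks_remaining(8192, 0, 4096, 16) == 256
assert shared_prompt_blocks_remaining(8192, 16, 4096, 16) == 255
```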
The main idea is to separately manage prompt and decode tokens in the cache manager, so that the former can be shared among the `n` samples in a parallel-sampling request. We need to be careful not to overwrite a shared block, which can arise, for example, when the prompt token count is not divisible by the block size.

My implementation is very different from the one in vLLM, but section 4.4 of their paper https://arxiv.org/pdf/2309.06180.pdf provides good background for parallel sampling. For example, the case above, where the prompt token count is not divisible by the block size, is solved by copying partially-shared prompt blocks to each decode sequence, as described in the paper.
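To illustrate the sharing scheme, here is a minimal sketch in which all `n` decode sequences reference the same prompt block ids and only append private decode blocks. The data structures are purely illustrative, not the cache manager's actual ones:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class SequenceBlockTable:
    prompt_blocks: List[int]                                # block ids shared by all n samples
    decode_blocks: List[int] = field(default_factory=list)  # private per-sample block ids

def allocate_parallel_samples(prompt_blocks: List[int], n: int) -> List[SequenceBlockTable]:
    # Every sample points at the *same* prompt block list, so the prompt KV cache is
    # stored only once; decode blocks are allocated separately for each sample.
    return [SequenceBlockTable(prompt_blocks=prompt_blocks) for _ in range(n)]

tables = allocate_parallel_samples(prompt_blocks=[0, 1, 2], n=4)
assert all(t.prompt_blocks is tables[0].prompt_blocks for t in tables)
```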
The SWA + parallel sampling case is difficult and I'm not happy with my current solution. It is not clean and not optimal in terms of free-block usage (I don't use the "reference count" of blocks that the paper describes and vLLM implements). This case is also difficult to test, so there could be more bugs. But I think this is in a reasonable state for a first cut.
Also importantly, evicting a parallel-sampling request is still not supported, since restoring its KV cache entries by recomputation is very challenging (vLLM doesn't support that either). I'll work on that later.
An example output from parallel sampling with SWA on an 8k prompt