
Parallel sampling support #106

Merged (53 commits) on Dec 14, 2023
Conversation

@masahi (Member) commented Dec 11, 2023

The main idea is to separately manage prompt and decode tokens in the cache manager, so that the former can be shared among n samples in a parallel-sampling request. We need to be careful not to overwrite a shared block, which can arise when

  • The prompt token count is not divisible by the block size
  • A prompt is longer than the window size of SWA.

My implementation is very different from the one in vllm, but section 4.4 of their paper https://arxiv.org/pdf/2309.06180.pdf provides good background on parallel sampling. For example, the first case above, "the prompt token count is not divisible by the block size", is solved by copying partially-shared prompt blocks to each decode sequence, as described in the paper and sketched below.
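To make the first case concrete, here is a minimal, self-contained sketch of the copy step (the data structures and names are simplified stand-ins, not the actual cache-manager code): fully filled prompt blocks are shared by all n samples, while a partially filled trailing block is copied per sample so that decode tokens never overwrite shared memory.

from dataclasses import dataclass, field
from typing import List


@dataclass
class BlockTable:
    blocks: List[int] = field(default_factory=list)  # physical block ids


def allocate_decode_tables(
    prompt_len: int,
    block_size: int,
    num_samples: int,
    prompt_table: BlockTable,
    free_blocks: List[int],
) -> List[BlockTable]:
    """Build one block table per sample from a shared prompt block table."""
    num_full = prompt_len // block_size
    has_partial = prompt_len % block_size != 0

    tables = []
    for _ in range(num_samples):
        # Fully filled prompt blocks are shared by reference among all samples.
        table = BlockTable(blocks=list(prompt_table.blocks[:num_full]))
        if has_partial:
            # The partially filled last block would be overwritten by the first
            # decode token, so each sample gets its own private copy of it.
            new_block = free_blocks.pop()
            # (A real cache manager would also copy the KV entries of the
            # partial block into new_block on the device.)
            table.blocks.append(new_block)
        tables.append(table)
    return tables


prompt = BlockTable(blocks=[0, 1, 2])  # 37 prompt tokens -> 3 blocks of size 16
free = [10, 11, 12, 13]
for i, t in enumerate(allocate_decode_tables(37, 16, 2, prompt, free)):
    print(f"sample {i}: {t.blocks}")  # blocks 0 and 1 shared; last block private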

The SWA + parallel sampling case is difficult and I'm not happy with my current solution. It is not clean and not optimal in terms of free-block usage (I don't use the "reference count" of blocks that the paper describes and vllm implements). This case is also difficult to test, so there could be more bugs. But I think this is in a reasonable state for the first cut.

Also importantly, evicting a parallel-sampling request is still not supported, since restoring its KV cache entries by recomputation is very challenging (vllm doesn't support that either). I'll work on that later.

An example output from parallel sampling with SWA on an 8k prompt:

Generated 0-th sample = ' House Of The Seven Hawks The director's name is Abhsihek Saxena Who's the one who produced and directed film by'

Generated 1-th sample = ' House of The Seven Hawks
Question: Who made films on hitting the perfect 10.

 Rumbi Katedza
Question'

Generated 2-th sample = ' House Of The Seven Hawks
Answers:
Passage 7;
Abhishek Saxena
Abhishek Sax'

@masahi marked this pull request as ready for review December 12, 2023 06:52

def get_prompt_sequence_id(request_id: RequestId) -> SequenceId:
    return SequenceId(request_id, PROMPT_SEQEUNCE_INDEX)

masahi (Member Author):
@sunggg @elvin-n Please be aware of this convention
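For illustration, a minimal sketch of this convention (the sentinel value of PROMPT_SEQEUNCE_INDEX and the SequenceId fields below are assumptions for the example, not necessarily the actual definitions): a request's shared prompt is tracked under a dedicated sequence index, while its n generated sequences use indices 0..n-1.

from typing import NamedTuple

RequestId = str
PROMPT_SEQEUNCE_INDEX = -1  # assumed sentinel for "the shared prompt sequence"


class SequenceId(NamedTuple):
    request_id: RequestId
    sequence_index: int


def get_prompt_sequence_id(request_id: RequestId) -> SequenceId:
    return SequenceId(request_id, PROMPT_SEQEUNCE_INDEX)


# One prompt sequence id plus n generation sequence ids for a request:
request_id = "req-0"
prompt_seq = get_prompt_sequence_id(request_id)
gen_seqs = [SequenceId(request_id, i) for i in range(3)]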

@masahi marked this pull request as draft December 12, 2023 21:48
@masahi marked this pull request as ready for review December 12, 2023 21:53
@sunggg changed the title from "Parallel sampling suppoort" to "Parallel sampling support" Dec 13, 2023
@psrivas2

Should we also update the assert at https://github.com/masahi/mlc-llm/blob/parallel-sampling-dev/serve/mlc_serve/engine/base.py#L148 to be if self.best_of > self.num_sequences:

@masahi (Member Author) commented Dec 14, 2023

I don't think we support the self.best_of > self.num_sequences case. We can generate best_of samples, but we don't have logic to choose num_sequences samples from them (because right now we don't preserve logprobs after sampling).

Besides, is there a point in allowing greedy sampling when best_of == num_sequences (which your condition allows)?
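As a sketch of the validation being discussed (a hypothetical function; the actual assert lives in engine/base.py and may look different):

def validate_sampling(best_of: int, num_sequences: int, temperature: float) -> None:
    if best_of != num_sequences:
        # Supporting best_of > num_sequences would require keeping logprobs
        # around to pick the top num_sequences samples, which isn't done yet.
        raise ValueError("best_of must equal num_sequences")
    if num_sequences > 1 and temperature == 0.0:
        # Greedy sampling would produce n identical sequences.
        raise ValueError("n > 1 requires non-zero temperature")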

    )

    if args.use_staging_engine:
        engine.stop()

    total_num_tokens = sum(
-       prompt_len + output_len for _, prompt_len, output_len in requests
+       prompt_len + output_len * args.num_sequences_to_sample for _, prompt_len, output_len in requests
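Spelled out, the updated accounting multiplies only the generated length by the number of sampled sequences per request, since the prompt is processed once and shared. A sketch (requests is assumed to hold (prompt, prompt_len, output_len) tuples as in the benchmark script):

def count_total_tokens(requests, num_sequences_to_sample: int) -> int:
    # The prompt is processed once and shared, while each of the n samples
    # produces its own output tokens.
    return sum(
        prompt_len + output_len * num_sequences_to_sample
        for _, prompt_len, output_len in requests
    )


requests = [("prompt a", 128, 32), ("prompt b", 64, 16)]
print(count_total_tokens(requests, num_sequences_to_sample=2))  # 128 + 64 + (32 + 16) * 2 = 288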

Not a request to change anything, just a reflection: the metric we calculate, tokens per second, might not reflect the real number of processed tokens, because with cache eviction we can process the same request several times. That is, we can process more tokens than are counted in total_num_tokens. Not sure whether it affects any conclusion or the overall picture.

Member:

IMO, we should not include re-computed tokens in this kind of benchmarking, to avoid a situation where throughput looks better than it should because of re-computation. If many re-computed tokens are needed to generate valid output tokens, that is a problem of the engine, and the benchmark script should reflect it.

masahi (Member Author):

I agree with @sunggg here

sampling_params=state.sampling_params,
)
)
cache_manager.extend(
gen_seq.seq_id,
len(token_ids) - gen_seq.next_start_position,
prompt_counts

Comparing the previous calculation of the number of new tokens with the new implementation, I don't understand why we need such a complex formula here. It depends on the previous tokens; why do we compute the number of new tokens from them? I would understand if this were just 1, but I don't see why we need the token count and the next position.

Can you point to decode cases where this is not 1?

@masahi (Member Author) Dec 14, 2023

Previously, len(token_ids) was the combined prompt and decode token count. This diff doesn't change anything.

I think we have future support for speculative decoding in mind here, where we may generate multiple tokens in one decoding step (see the sketch below).
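To illustrate why the expression is 1 for an ordinary decode step (a toy sketch; GenSeq and its fields are simplified stand-ins for the engine's sequence state):

class GenSeq:
    def __init__(self):
        self.token_ids = []           # all tokens of this sequence so far
        self.next_start_position = 0  # tokens already accounted for in the cache


seq = GenSeq()
seq.token_ids = list(range(10))   # prompt tokens already processed
seq.next_start_position = 10

# Ordinary decode: exactly one new token per step.
seq.token_ids.append(42)
assert len(seq.token_ids) - seq.next_start_position == 1

# A (future) speculative-decoding step could append several accepted tokens at
# once, and the same expression would then request cache space for all of them.
seq.next_start_position = len(seq.token_ids)
seq.token_ids.extend([7, 8, 9])
assert len(seq.token_ids) - seq.next_start_position == 3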


> This diff doesn't change anything.

I absolutely agree with this.

> I think we have future support for speculative decoding in mind here

I can hardly imagine decoding several tokens without significant modification of the logic.

Again, I'm not proposing a change right now, but originally this only needs to be extend(gen_seq.seq_id, 1).

@masahi (Member Author) Dec 14, 2023

Yes, it's going to be the next big project. It will be very cool.

self.current_batch[request_to_remove.request_id].num_sequences == 1
), "Evicting a multi-sequence request is not supported."

# TODO(masahi): Properly support evicting a multi-sequence request

Disappointing solution. Let's at least change the algorithm for choosing which request to evict: select from the requests with n == 1, and only if no such request exists, evict one with n > 1 (see the sketch below).
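A sketch of the suggested preference (the Request type and fields here are hypothetical; this illustrates the heuristic, not code from the PR):

from typing import List, Optional


class Request:
    def __init__(self, request_id: str, num_sequences: int):
        self.request_id = request_id
        self.num_sequences = num_sequences


def pick_eviction_candidate(requests: List[Request]) -> Optional[Request]:
    # Prefer requests with a single sequence, since multi-sequence (parallel
    # sampling) requests cannot be restored by recomputation yet.
    for req in requests:
        if req.num_sequences == 1:
            return req
    # Only if there is no single-sequence request, fall back to a
    # multi-sequence one.
    return requests[0] if requests else None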

Member:

I think this is a temporary solution until we land eviction support. I'm okay with it as a temporary solution if we can follow up quickly.

@masahi (Member Author) Dec 14, 2023

I'll follow up on this today. This is indeed a temporary solution, but even after we have a proper one, it's still good to make a best effort to avoid evicting a parallel-sampling request.

):
# This sequence is trying to overwrite a prompt block shared with other sequences.

# TODO(masahi): The engine should take into account this additional

The comment is not clear. The verification and handling seem correct. What additionally should the engine take into account, and in which case?

@masahi (Member Author) Dec 14, 2023

Thinking about it more, we probably don't need to worry about this, since the engine conservatively assumes that each decode step may allocate one block per sequence (correct me if I'm wrong about this). Whenever we hit this code path, we don't do the allocation at L226. So in all code paths we allocate at most one block per sequence per decode step.
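A rough sketch of that conservative budget as I read it (hypothetical names; the real check lives inside the engine and may look different):

def can_run_decode_step(num_sequences_in_batch: int, num_free_blocks: int) -> bool:
    # Worst case, every sequence either crosses a block boundary or copies a
    # shared prompt block instead of appending in place: at most one new block
    # per sequence per decode step.
    return num_free_blocks >= num_sequences_in_batch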

@psrivas2 commented Dec 14, 2023

> Besides, is there a point in allowing greedy sampling when best_of == num_sequences (which your condition allows)?

Won't we run into this assert if the request has n >= 2 and temperature = 0.0? I understand this is not a realistic use case, but the user might be surprised if such a request fails.

@sunggg (Member) left a comment

Thank you @masahi for another great piece of work! Overall, LGTM. I'm also okay with the temporary solutions you mentioned for the sake of fast iteration, so I'd like us to merge this quickly and follow up quickly.


num_tokens % self.block_size != 0
)

if self.sliding_window:
Member:

Orthogonal to this PR, but what happens if we have more prompt tokens than the sliding window size?

@masahi (Member Author) Dec 14, 2023

That's a very common case with Mistral and long prompts. The code paths in

https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/model/paged_cache_model.py#L65-L73
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/model/paged_cache_model.py#L110-L114

are specifically for that case. The prompt blocks already wrap around before any decoding happens due to circular buffering, so at each decoding step we need to carefully determine which portion of the prompt blocks is still shared among the n samples.
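A toy sketch of the bookkeeping this implies (simplified to token granularity with hypothetical names; the actual code works block by block on the circular buffer): each decode token overwrites the slot of the oldest resident token, so the still-shared prompt region shrinks as decoding proceeds.

def shared_prompt_tokens(prompt_len: int, num_decode_tokens: int,
                         sliding_window: int) -> int:
    # A long prompt has already wrapped around: only its last `sliding_window`
    # tokens are resident when decoding starts.
    resident_prompt = min(prompt_len, sliding_window)
    # Each decode token overwrites the slot of the oldest resident token.
    return max(0, resident_prompt - num_decode_tokens)


# 8k prompt with a 4k window (e.g. Mistral): 4096 prompt tokens are resident at
# the start of decoding, and each decode step evicts one more of them.
print(shared_prompt_tokens(8192, 0, 4096))    # 4096
print(shared_prompt_tokens(8192, 100, 4096))  # 3996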


@sunggg merged commit c767814 into octoml:batch-serving on Dec 14, 2023
1 check passed
Lunderberg pushed a commit to Lunderberg/mlc-llm that referenced this pull request Jan 30, 2024