diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index b66dea3479..a689e43888 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -19,8 +19,8 @@ class RawLogprobsInfo:
     current_token_id: int
     current_logprob: float
-    top_token_ids: Optional[np.array]
-    top_logprobs: Optional[np.array]
+    top_token_ids: Optional[np.ndarray]
+    top_logprobs: Optional[np.ndarray]
 
 
 RawLogprobsInfos = List[Optional[RawLogprobsInfo]]
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 675e97e173..4a02ce2f60 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -22,9 +22,12 @@ from .model_module import (
     DecodeRequest,
     PrefillRequest,
+    EvalMultiQueryRequest,
+    EvictedTokens,
     ConversationTemplate,
     KVCacheManager,
     ModelModule,
+    RequestType,
     TextGenerator,
     Tokenizer as TokenizerP,
 )
@@ -226,26 +229,71 @@ def update_sequence(
 def get_requests_to_process(
     current_states: list[RequestState], cache_manager: KVCacheManager
-) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
-    requests: list[Union[PrefillRequest, DecodeRequest]] = []
+) -> Tuple[list[RequestType], bool, int]:
+    requests: list[RequestType] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
     # mixing prefill and decode.
     is_prompt_batch = any(not state.is_prefilled for state in current_states)
     token_counts = 0
 
+    is_evicted_parallel_sampling_request = (
+        lambda state: not state.is_prefilled
+        and state.num_sequences > 1
+        and any(
+            len(gen_seq.generated_token_ids) > 0
+            for gen_seq in state.generation_sequences
+        )
+    )
+
     if is_prompt_batch:
         for state in current_states:
-            if not state.is_prefilled:
+            if is_evicted_parallel_sampling_request(state):
                 requests.append(
+                    PrefillRequest(
+                        request_id=state.request_id,
+                        token_ids=state.prompt_token_ids,
+                        num_sequence=state.num_sequences,
+                        sampling_params=state.sampling_params,
+                    )
+                )
+
+                token_counts += len(state.prompt_token_ids)
+
+                for gen_seq in state.generation_sequences:
+                    requests.append(
+                        EvalMultiQueryRequest(
+                            sequence_id=gen_seq.seq_id,
+                            num_past_tokens=state.prompt_len,
+                            queries=EvictedTokens(gen_seq.generated_token_ids),
+                            sampling_params=state.sampling_params,
+                        )
+                    )
+                    cache_manager.extend(
+                        gen_seq.seq_id,
+                        len(gen_seq.generated_token_ids) + 1,
+                    )
+
+                # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
+                # Prometheus metric?
+            elif not state.is_prefilled:
+                if (
+                    state.num_sequences == 1
+                    and state.generation_sequences[0].generated_token_ids
+                ):
                     # generated_token_ids is added for the case where the request is
                     # recovering from cache eviction.
-                    # TODO(masahi): This needs an update when we support evicting
-                    # a parallel-sampling request.
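+                    # The parallel-sampling case (num_sequences > 1) is handled by the
+                    # is_evicted_parallel_sampling_request branch above; here a single
+                    # evicted sequence is simply re-prefilled with its prompt plus the
+                    # tokens it had already generated.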
+                    token_ids = (
+                        state.prompt_token_ids
+                        + state.generation_sequences[0].generated_token_ids
+                    )
+                else:
+                    token_ids = state.prompt_token_ids
+
+                requests.append(
                     PrefillRequest(
                         request_id=state.request_id,
-                        token_ids=state.prompt_token_ids
-                        + state.generation_sequences[0].generated_token_ids,
+                        token_ids=token_ids,
                         num_sequence=state.num_sequences,
                         sampling_params=state.sampling_params,
                     )
@@ -392,16 +440,28 @@ def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
                 candidate_victims = parallel_sample_requests
 
             request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens)
-
-            # TODO(masahi): Properly support evicting a multi-sequence request
-            if self.current_batch[request_to_remove.request_id].num_sequences != 1:
-                cancell_callback(request_to_remove.request_id)
-                self.remove_request_from_batch(request_to_remove.request_id)
-                LOG.warn(
-                    "Preempting a multi-sequence request is currently not supported,"
-                    f" cancelling request '{request_to_remove.request_id}'",
+            victim_state = self.current_batch[request_to_remove.request_id]
+
+            if victim_state.num_sequences != 1:
+                prev_generated_token_counts = sum(
+                    [
+                        len(gen_seq.generated_token_ids)
+                        for gen_seq in victim_state.generation_sequences
+                    ]
                 )
-                continue
+                # We could allow evicting and restoring a parallel-sampling request whose prev_generated_token_counts
+                # is > max_num_batched_tokens, by making the model split a list of EvalMultiQuery requests into parts,
+                # so that an inference on each part can be done with the max_num_batched_tokens budget.
+                # But this introduces an undesirable coupling between the engine and the model.
+                if prev_generated_token_counts >= self.max_num_batched_tokens:
+                    cancell_callback(request_to_remove.request_id)
+                    self.remove_request_from_batch(request_to_remove.request_id)
+                    LOG.warn(
+                        f"Cancelling the parallel-sampling request '{request_to_remove.request_id}' "
+                        f"since it has generated more than {self.max_num_batched_tokens} tokens in total, "
+                        "and preempting such a request is currently not supported.",
+                    )
+                    continue
 
             self.remove_request_from_batch(request_to_remove.request_id)
             request_to_remove.is_prefilled = False
@@ -446,14 +506,27 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
                 gen_seq.next_start_position = (
                     num_new_batched_tokens
                 ) = num_tokens = self.max_num_batched_tokens
+
+            num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
         else:
-            # Evicting and recovering multi-sequence requests is not supported for now.
-            assert all(
-                gen_seq.next_start_position == state.prompt_len
-                for gen_seq in state.generation_sequences
+            prev_generated_token_counts = sum(
+                [
+                    len(gen_seq.generated_token_ids)
+                    for gen_seq in state.generation_sequences
+                ]
             )
+
+            # Restoring an evicted parallel-sampling request with sliding-window attention is
+            # difficult to reason about, so we use crude upper bounds below for now.
            num_tokens = state.prompt_len
-            num_new_batched_tokens += num_tokens
+            num_kv_slots_needed = state.prompt_len + prev_generated_token_counts
+            # Restoring an evicted parallel-sampling request is done by separate
+            # Prefill and EvalMultiQuery requests. The maximum below is an upper bound on the
+            # batch size increase due to this request.
+            # TODO(masahi): Prefill and EvalMultiQuery requests are handled separately by the model.
+            # So comparing the sum of their batched token counts against max_num_batched_tokens
+            # is not optimal.
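+            # Illustrative example (assumed numbers): with prompt_len = 100 and three
+            # generation sequences that had each produced 40 tokens before eviction,
+            # prev_generated_token_counts = 120, so max(100, 120) = 120 tokens are
+            # charged against the max_num_batched_tokens budget for this request.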
+            num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts)
 
         if num_new_batched_tokens > self.max_num_batched_tokens:
             LOG.debug(
@@ -465,7 +538,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
         # We make sure that the KV cache will have enough free space for this request to proceed
         # decoding for at least self.max_decode_steps steps.
         # See the comment in check_prompt_too_long for the optimization involving the window size.
-        num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
         if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
             len(self.current_batch) + 1
         ) < self.max_decode_steps * state.num_sequences:
@@ -477,7 +549,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
             return None
 
         self.queue.popleft()
-        # TODO parallel sampling: Need update here when evicting multi-sequence requests is supported.
         self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
         self.current_batch[state.request_id] = state
 
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 076f9882fe..a5b86d69b9 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -40,6 +40,36 @@ class DecodeRequest:
     sampling_params: SamplingParams
 
 
+@dataclass
+class DraftTokens:
+    token_ids: List[int]
+
+    @property
+    def num_tokens(self):
+        return len(self.token_ids)
+
+
+@dataclass
+class EvictedTokens:
+    token_ids: List[int]
+
+    @property
+    def num_tokens(self):
+        return len(self.token_ids)
+
+
+@dataclass
+class EvalMultiQueryRequest:
+    sequence_id: SequenceId
+    num_past_tokens: int
+    queries: Union[DraftTokens, EvictedTokens]
+    sampling_params: SamplingParams
+
+
+RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
+RequestsType = Sequence[RequestType]
+
+
 @dataclass
 class TextGenerationResult:
     """
@@ -125,7 +155,7 @@ class TextGenerator(Protocol):
 
     def generate(
         self,
-        requests: Sequence[Union[PrefillRequest, DecodeRequest]],
+        requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
         kv_cache,
     ) -> List[TextGenerationResult]:
         """
diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py
index 1028bd555a..ab5f437667 100644
--- a/serve/mlc_serve/model/model_common.py
+++ b/serve/mlc_serve/model/model_common.py
@@ -9,10 +9,21 @@ from ..engine import (
     SamplingType,
     SamplingParams,
+    get_prompt_sequence_id,
     LOGPROB_TOP_K_MAX,
+    PROMPT_SEQEUNCE_INDEX,
     RawLogprobsInfo,
     RawLogprobsInfos,
+    SequenceId,
 )
+from ..engine.model_module import (
+    PrefillRequest,
+    EvalMultiQueryRequest,
+    RequestType,
+    RequestsType,
+    TextGenerationResult,
+)
+
 
 
 LOG = structlog.stdlib.get_logger(__name__)
 
@@ -65,8 +76,8 @@ def get_raw_logprob_info(
     top_logprobs, top_tokens = torch.topk(
         logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True
     )
-    top_tokens=top_tokens.cpu().numpy()
-    top_logprobs=top_logprobs.cpu().numpy()
+    top_tokens = top_tokens.cpu().numpy()
+    top_logprobs = top_logprobs.cpu().numpy()
 
     # Set to raw logprob info
     return RawLogprobsInfo(
@@ -106,7 +117,7 @@ def get_raw_logprob_infos(
     logits: torch.Tensor,
     token_ids: torch.Tensor,
 ) -> RawLogprobsInfos:
-    for (i, ind, top_logprobs) in indices:
+    for i, ind, top_logprobs in indices:
         logprob_infos[i] = get_raw_logprob_info(
             logits[ind],
             token_ids[ind],
@@ -292,6 +303,111 @@ def _is_safe_to_sample(prob_like):
     return res, check_logprob_infos(logprob_infos)
 
 
+def update_tokens_frequency(
+    request: RequestType,
+    new_token: int
+):
+    if new_token not in request.sampling_params.appeared_tokens_freq:
+        request.sampling_params.appeared_tokens_freq[new_token] = 0
+    request.sampling_params.appeared_tokens_freq[new_token] += 1
+
+
+def append_text_gen_res(
+    outputs: List[TextGenerationResult],
+    request: RequestType,
+    new_token: List[int],
+    sequence_id: SequenceId,
+    logprob_info: Optional[RawLogprobsInfos],
+    err_msg: Optional[str] = None,
+) -> List[TextGenerationResult]:
+    if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
+        assert isinstance(request, PrefillRequest)
+        for seq_id in range(request.num_sequence):  # type: ignore
+            outputs.append(
+                TextGenerationResult(
+                    sequence_id=SequenceId(sequence_id.request_id, seq_id),
+                    generated_tokens=new_token,
+                    error=err_msg,
+                    logprob_info=logprob_info,
+                )
+            )
+    else:
+        outputs.append(
+            TextGenerationResult(
+                sequence_id=sequence_id,
+                generated_tokens=new_token,
+                error=err_msg,
+                logprob_info=logprob_info,
+            )
+        )
+    return outputs
+
+
+def sample_from_logits(
+    logits: Union[tvm.nd.NDArray, torch.Tensor],
+    sequence_ids: List[SequenceId],
+    requests: RequestsType,
+    vocab_size,
+) -> List[TextGenerationResult]:
+    assert logits.shape[0] == len(requests)
+
+    sampling_params = [req.sampling_params for req in requests]
+    outputs: List[TextGenerationResult] = []
+
+    try:
+        next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)
+        assert next_tokens is not None
+        for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)):
+            update_tokens_frequency(requests[i], new_token)
+            outputs = append_text_gen_res(
+                outputs,
+                requests[i],
+                [new_token],
+                sequence_id,
+                get_logprob_infos(i, logprob_infos),
+            )
+
+        return outputs
+    except RuntimeError:
+        # Fallback to per-token sampling in case some logits values are corrupted.
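+        # Each sequence is retried one at a time with check_safety=True; sequences whose
+        # logits still cannot be sampled from get an empty result carrying err_msg below,
+        # instead of failing the whole batch.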
+        err_msg = (
+            "Error from sampling: probability tensor contains either `inf`, `nan`"
+            " or element < 0"
+        )
+
+        for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
+            zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
+        ):
+            maybe_new_token, logprob_infos = sample(
+                torch.unsqueeze(logits_per_token, 0),
+                [sampling_param],
+                vocab_size,
+                check_safety=True,
+            )
+
+            if maybe_new_token is not None:
+                new_token = maybe_new_token[0]
+                update_tokens_frequency(requests[i], new_token)
+                outputs = append_text_gen_res(
+                    outputs,
+                    requests[i],
+                    [new_token],
+                    sequence_id,
+                    get_logprob_infos(0, logprob_infos),
+                )
+            else:
+                outputs = append_text_gen_res(
+                    outputs,
+                    requests[i],
+                    [],  # new_token
+                    sequence_id,
+                    get_logprob_infos(0, logprob_infos),
+                    err_msg,
+                )
+
+    return outputs
+
+
 def prepare_inputs(
     sequence_ids,
     all_token_ids,
@@ -378,3 +494,92 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]:
         indices_within_window,
         block_tables,
     )
+
+
+def prepare_multi_query_decode_inputs(
+    requests: List[EvalMultiQueryRequest],
+    all_slot_mappings,
+    sliding_window,
+    dev,
+):
+    seq_lens = []
+    query_lens = []
+    input_ids = []
+    slot_mapping = []
+    past_slot_mapping = []
+    positions = []
+    permute_map = []
+
+    query_offset = sum([request.num_past_tokens for request in requests])
+    past_offset = 0
+
+    for request in requests:
+        num_queries = request.queries.num_tokens
+        query_lens.append(num_queries)
+        input_ids += request.queries.token_ids
+        positions += [request.num_past_tokens + i for i in range(num_queries)]
+
+        prompt_seq_id = get_prompt_sequence_id(request.sequence_id.request_id)
+        prompt_slot_mappings = all_slot_mappings[prompt_seq_id]
+
+        if sliding_window and request.num_past_tokens + num_queries >= sliding_window:
+            seq_lens.append(sliding_window)
+            prompt_and_decode_slot_mappings = (
+                prompt_slot_mappings + all_slot_mappings[request.sequence_id]
+            )
+            past_slot_mapping += prompt_and_decode_slot_mappings[
+                request.num_past_tokens
+                - (sliding_window - num_queries) : request.num_past_tokens
+            ]
+            slot_mapping += prompt_and_decode_slot_mappings[
+                request.num_past_tokens : request.num_past_tokens + num_queries
+            ]
+        else:
+            seq_lens.append(request.num_past_tokens + num_queries)
+
+            if request.num_past_tokens < len(prompt_slot_mappings):
+                raise RuntimeError(
+                    "For EvalMultiQueryRequest, having fewer past tokens than "
+                    "the prompt length is not supported for now."
+                )
+            elif request.num_past_tokens == len(prompt_slot_mappings):
+                # The case for restoring an evicted parallel-sampling request
+                past_slot_mapping += prompt_slot_mappings
+                slot_mapping += all_slot_mappings[request.sequence_id][:num_queries]
+            else:
+                query_begin_offset = request.num_past_tokens - len(prompt_slot_mappings)
+                past_slot_mapping += (
+                    prompt_slot_mappings
+                    + all_slot_mappings[request.sequence_id][:query_begin_offset]
+                )
+                slot_mapping += all_slot_mappings[request.sequence_id][
+                    query_begin_offset : query_begin_offset + num_queries
+                ]
+
+        permute_map += list(
+            range(past_offset, past_offset + request.num_past_tokens)
+        ) + list(range(query_offset, query_offset + num_queries))
+
+        query_offset += num_queries
+        past_offset += request.num_past_tokens
+
+    input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev)
+    positions = tvm.nd.array(np.array(positions, dtype="int32"), dev)
+    seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev)
+    slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev)
+
+    query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev)
+    # TODO(masahi): These inputs need to be replaced by block_table when a proper attention kernel
+    # becomes available.
+    past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev)
+    permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev)
+
+    return (
+        input_ids,
+        positions,
+        seq_lens,
+        slot_mapping,
+        query_lens,
+        past_slot_mapping,
+        permute_map,
+    )
diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py
index e3c466a011..c1bdbc1f7f 100644
--- a/serve/mlc_serve/model/paged_cache_manager.py
+++ b/serve/mlc_serve/model/paged_cache_manager.py
@@ -172,50 +172,60 @@ def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]):
             elif id in self.kv_cache_info.decode_block_tables:
                 decode_block_table = self.kv_cache_info.decode_block_tables[id]
 
-                if len(decode_block_table) < num_needed_block:
+                while len(decode_block_table) < num_needed_block:
                     # Need to allocate a new block for this request
-                    assert len(decode_block_table) + 1 == num_needed_block
                     assert len(self.free_blocks) > 0
                     decode_block_table.append(self.free_blocks.pop())
 
-                pos = size - 1
+                prompt_seq_id = get_prompt_sequence_id(id.request_id)
+                allocated_slot_counts = len(
+                    self.kv_cache_info.slot_mappings[prompt_seq_id]
+                ) + len(self.kv_cache_info.slot_mappings[id])
 
-                def get_block_circular_index(token_pos):
-                    assert self.block_sliding_window
-                    return (token_pos // self.block_size) % self.block_sliding_window
+                for current_size in range(allocated_slot_counts + 1, size + 1):
+                    pos = current_size - 1
 
-                if (
-                    decode_block_table.prompt_shared
-                    and self.sliding_window
-                    and size >= self.sliding_window
-                ):
-                    # Parallel sampling + SWA case
-                    if decode_block_table.prompt_cursor == get_block_circular_index(
-                        pos
+                    def get_block_circular_index(token_pos):
+                        assert self.block_sliding_window
+                        return (
+                            token_pos // self.block_size
+                        ) % self.block_sliding_window
+
+                    if (
+                        decode_block_table.prompt_shared
+                        and self.sliding_window
+                        and current_size >= self.sliding_window
                     ):
-                        # This sequence is trying to overwrite a prompt block shared with other sequences.
-                        assert (
-                            len(self.free_blocks) > 0
-                        ), "No more free block in the cache."
-
-                        block_number = self.free_blocks.pop()
-                        # Add a new decode block and advance the prompt cursor
-                        decode_block_table.replace_head_prompt_block_with(block_number)
-                    else:
-                        # Write to the decode block allocated above
-                        block_number = decode_block_table[-1]
+                        # Parallel sampling + SWA case
+                        if (
+                            decode_block_table.prompt_cursor
+                            == get_block_circular_index(pos)
+                        ):
+                            # This sequence is trying to overwrite a prompt block shared with other sequences.
+                            assert (
+                                len(self.free_blocks) > 0
+                            ), "No more free block in the cache."
+
+                            block_number = self.free_blocks.pop()
+                            # Add a new decode block and advance the prompt cursor
+                            decode_block_table.replace_head_prompt_block_with(
+                                block_number
+                            )
+                        else:
+                            # Write to the decode block allocated above
+                            block_number = decode_block_table[-1]
 
-                else:
-                    if self.block_sliding_window:
-                        index = get_block_circular_index(pos)
                     else:
-                        index = -1
+                        if self.block_sliding_window:
+                            index = get_block_circular_index(pos)
+                        else:
+                            index = -1
 
-                    block_number = decode_block_table[index]
+                        block_number = decode_block_table[index]
 
-                block_offset = pos % self.block_size
-                slot = block_number * self.block_size + block_offset
-                self.kv_cache_info.slot_mappings[id].append(slot)
+                    block_offset = pos % self.block_size
+                    slot = block_number * self.block_size + block_offset
+                    self.kv_cache_info.slot_mappings[id].append(slot)
 
             elif id not in self.kv_cache_info.prompt_block_tables:
                 assert (
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 7daf3336f4..433ca2baa3 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 import structlog
-from typing import List, Union
+from typing import List
 
 from .base import get_model_artifact_config
 from .paged_cache_manager import CacheManager
@@ -12,6 +12,8 @@
     DecodeRequest,
     ModelModule,
     PrefillRequest,
+    EvalMultiQueryRequest,
+    RequestType,
     TextGenerationResult,
     TextGenerator,
 )
@@ -24,17 +26,44 @@ def __init__(self, model: TextGenerator):
         self.model = model
 
     def generate(
-        self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache
+        self,
+        requests: List[RequestType],
+        kv_cache,
     ) -> List[TextGenerationResult]:
-        prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)]
-        decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]
+        prefill_requests = []
+        decode_requests = []
+        multi_query_decode_requests = []
+        multi_query_decode_request_ids = set()
+
+        for r in requests:
+            if isinstance(r, PrefillRequest):
+                prefill_requests.append(r)
+            elif isinstance(r, DecodeRequest):
+                decode_requests.append(r)
+            elif isinstance(r, EvalMultiQueryRequest):
+                multi_query_decode_requests.append(r)
+                multi_query_decode_request_ids.add(r.sequence_id.request_id)
 
         out = []
+
         if prefill_requests:
-            out.extend(self.model.generate(prefill_requests, kv_cache))
+            prefill_res = self.model.generate(prefill_requests, kv_cache)
+
+            if not multi_query_decode_requests:
+                out.extend(prefill_res)
+            else:
+                # Prefill requests from restoration of evicted parallel-sampling requests
+                # must not return outputs.
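+                # Their next tokens come from the EvalMultiQuery pass further below,
+                # so only results of regular prefill requests are collected here.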
+                for res in prefill_res:
+                    if res.sequence_id.request_id not in multi_query_decode_request_ids:
+                        out.append(res)
+
         if decode_requests:
             out.extend(self.model.generate(decode_requests, kv_cache))
 
+        if multi_query_decode_requests:
+            out.extend(self.model.generate(multi_query_decode_requests, kv_cache))
+
         return out
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index f1950a7417..6c37303c9f 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -1,6 +1,6 @@
 import math
 import os
-from typing import List, Optional, Union, Tuple, Sequence
+from typing import List, Tuple
 
 import structlog
 import numpy as np
@@ -12,22 +12,22 @@
 from .base import ModelArtifactConfig
 from .paged_cache_manager import KVCacheInfo, CacheManager
 from .model_common import (
-    sample,
+    sample_from_logits,
     prepare_inputs,
-    get_logprob_infos,
+    prepare_multi_query_decode_inputs,
     get_num_cache_blocks,
 )
 
 from ..engine import (
-    PROMPT_SEQEUNCE_INDEX,
-    RawLogprobsInfos,
-    SequenceId,
     get_prompt_sequence_id,
     MLCServeEngineConfig,
 )
 from ..engine.model_module import (
     DecodeRequest,
+    DraftTokens,
+    EvalMultiQueryRequest,
     PrefillRequest,
+    RequestsType,
     TextGenerationResult,
     TextGenerator,
 )
@@ -212,18 +212,92 @@ def profile_memory_usage(self, seq_lens):
         return self.get_used_memory()
 
+    def generate_multi_query(
+        self,
+        requests: List[EvalMultiQueryRequest],
+        cache: KVCacheInfo,
+    ) -> List[TextGenerationResult]:
+        sequence_ids = []
+        last_query_offsets: List[int] = []
+        for request in requests:
+            assert not isinstance(request.queries, DraftTokens)
+            sequence_ids.append(request.sequence_id)
+
+            if len(last_query_offsets) == 0:
+                last_query_offsets.append(request.queries.num_tokens - 1)
+            else:
+                last_query_offsets.append(
+                    last_query_offsets[-1] + request.queries.num_tokens
+                )
+
+        (
+            input_ids,
+            positions,
+            seq_lens,
+            slot_mapping,
+            query_lens,
+            past_slot_mapping,
+            permute_map,
+        ) = prepare_multi_query_decode_inputs(
+            requests,
+            cache.slot_mappings,
+            None,
+            self.dev,
+        )
+
+        torch.cuda.nvtx.range_push(f"forward multi-query decode {input_ids.shape}")
+
+        if self.disco_session:
+            input_ids = copy_to_worker_0(self.disco_session, input_ids)
+            positions = copy_to_worker_0(self.disco_session, positions)
+            seq_lens = copy_to_worker_0(self.disco_session, seq_lens)
+            slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping)
+            query_lens = copy_to_worker_0(self.disco_session, query_lens)
+            past_slot_mapping = copy_to_worker_0(self.disco_session, past_slot_mapping)
+            permute_map = copy_to_worker_0(self.disco_session, permute_map)
+
+        out = self.mod["evaluate_multi_query"](
+            input_ids,
+            positions,
+            seq_lens,
+            self.cache_blocks,
+            slot_mapping,
+            query_lens,
+            past_slot_mapping,
+            permute_map,
+            self.params,
+        )
+
+        if self.disco_session:
+            logits, _ = out.debug_get_from_remote(0)
+        else:
+            logits = out[0]
+
+        torch.cuda.synchronize()
+        torch.cuda.nvtx.range_pop()
+
+        last_query_logits = torch.from_dlpack(logits)[last_query_offsets]
+
+        return sample_from_logits(
+            last_query_logits, sequence_ids, requests, self.vocab_size
+        )
+
     def generate(
         self,
-        requests: Sequence[Union[PrefillRequest, DecodeRequest]],
+        requests: RequestsType,
         cache: KVCacheInfo,
     ) -> List[TextGenerationResult]:
         if len(requests) == 0:
             return []
 
         is_prefill = isinstance(requests[0], PrefillRequest)
+        is_multi_query_decode = isinstance(requests[0], EvalMultiQueryRequest)
+
+        if is_multi_query_decode:
+            return self.generate_multi_query(requests, cache)  # type: ignore
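+        # Note: a batch reaching this point is assumed to be homogeneous (all prefill
+        # or all decode); PagedCacheModelTextGenerator.generate splits requests by type
+        # before calling into the model.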
+        # Prefill or decode
         all_token_ids = []
-        sampling_params = []
         sequence_ids = []
         prompt_lens = []
         num_sequences = []
@@ -234,13 +308,12 @@ def generate(
         for request in requests:
             if isinstance(request, PrefillRequest):
                 sequence_ids.append(get_prompt_sequence_id(request.request_id))
-                num_sequences.append(request.num_sequence)
-            else:
+            elif isinstance(request, DecodeRequest):
                 sequence_ids.append(request.sequence_id)
                 prompt_lens.append(request.prompt_token_counts)
 
+            assert not isinstance(request, EvalMultiQueryRequest)
             all_token_ids.append(request.token_ids)
-            sampling_params.append(request.sampling_params)
 
         (
             input_ids,
@@ -352,108 +425,7 @@ def generate(
             # TODO(masahi, yelite): Proper logic for handling multi-query logits (speculative decoding).
             return []
 
-        try:
-            next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size)
-            assert next_tokens is not None
-            outputs = []
-            for i, (sequence_id, new_token) in enumerate(
-                zip(sequence_ids, next_tokens)
-            ):
-                if not new_token in requests[i].sampling_params.appeared_tokens_freq:
-                    requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
-                requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
-                if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
-                    for seq_id in range(num_sequences[i]):
-                        outputs.append(
-                            TextGenerationResult(
-                                sequence_id=SequenceId(sequence_id.request_id, seq_id),
-                                generated_tokens=[new_token],
-                                error=None,
-                                logprob_info=get_logprob_infos(i, logprob_infos),
-                            )
-                        )
-                else:
-                    outputs.append(
-                        TextGenerationResult(
-                            sequence_id=sequence_id,
-                            generated_tokens=[new_token],
-                            error=None,
-                            logprob_info=get_logprob_infos(i, logprob_infos),
-                        )
-                    )
-
-            return outputs
-        except RuntimeError:
-            # Fallback to per-token sampling in case some logits values are corrupted.
-            outputs = []
-            err_msg = (
-                "Error from sampling: probability tensor contains either `inf`, `nan`"
-                " or element < 0"
-            )
-
-            for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
-                zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
-            ):
-                maybe_new_token, logprob_infos = sample(
-                    torch.unsqueeze(logits_per_token, 0),
-                    [sampling_param],
-                    self.vocab_size,
-                    check_safety=True,
-                )
-
-                if maybe_new_token is not None:
-                    new_token = maybe_new_token[0]
-                    if (
-                        not new_token
-                        in requests[i].sampling_params.appeared_tokens_freq
-                    ):
-                        requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
-                    requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
-                    if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
-                        for seq_id in range(num_sequences[i]):
-                            outputs.append(
-                                TextGenerationResult(
-                                    sequence_id=SequenceId(
-                                        sequence_id.request_id, seq_id
-                                    ),
-                                    generated_tokens=[new_token],  # type: ignore
-                                    error=None,
-                                    logprob_info=get_logprob_infos(0, logprob_infos),
-                                )
-                            )
-                    else:
-                        outputs.append(
-                            TextGenerationResult(
-                                sequence_id=sequence_id,
-                                generated_tokens=[new_token],  # type: ignore
-                                error=None,
-                                logprob_info=get_logprob_infos(0, logprob_infos),
-                            )
-                        )
-                else:
-                    if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
-                        for seq_id in range(num_sequences[i]):
-                            outputs.append(
-                                TextGenerationResult(
-                                    sequence_id=SequenceId(
-                                        sequence_id.request_id, seq_id
-                                    ),
-                                    generated_tokens=[],
-                                    error=err_msg,
-                                    logprob_info=get_logprob_infos(0, logprob_infos),
-                                )
-                            )
-                    else:
-                        outputs.append(
-                            TextGenerationResult(
-                                sequence_id=sequence_id,
-                                generated_tokens=[],
-                                error=err_msg,
-                                logprob_info=get_logprob_infos(0, logprob_infos),
-                            )
-                        )
-
-            return outputs
+        return sample_from_logits(logits, sequence_ids, requests, self.vocab_size)
 
 
 def init_tvm_model(