diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 3dd52269b3..21fbb9d044 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -16,6 +16,7 @@
     RequestState,
     SequenceId,
     StoppingCriteria,
+    RawLogprobsInfo,
 )
 from .model_module import (
     DecodeRequest,
@@ -80,8 +81,8 @@ def detokenize_incrementally(
     prompt_tokens: list[int],
     generation_sequence: GenerationSequence,
     tokenizer: TokenizerP,
-    new_token_id=None,
-    skip_special_tokens=False,
+    new_token_id: Optional[int] = None,
+    skip_special_tokens: bool = False,
 ) -> str:
     # tokenizer.decode() is similar to doing convert_tokens_to_string(convert_ids_to_tokens(token_ids))
     # in this function, we separate these two steps.
@@ -89,11 +90,10 @@
     new_token_id = (
         generation_sequence.generated_token_ids[-1] if not is_logprob else new_token_id
     )
-
     # This is the first iteration for this sequence
     if generation_sequence.prev_tokens is None:
         # TODO(masahi): Figure out a way to remove this concat
-        new_tokens: Union[str, List[str]] = tokenizer.convert_ids_to_tokens(
+        new_tokens: List[str] = tokenizer.convert_ids_to_tokens(  # type: ignore
            prompt_tokens + generation_sequence.generated_token_ids
        )
        output_tokens = new_tokens
@@ -109,7 +109,7 @@
         prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens: Union[str, List[str]] = tokenizer.convert_ids_to_tokens(
+        new_tokens: List[str] = tokenizer.convert_ids_to_tokens(  # type: ignore
             [new_token_id]
         )
         output_tokens = generation_sequence.prev_tokens + new_tokens
@@ -170,12 +170,12 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
 
 
 def prepare_logprob(
-    logprob_info,
-    delta,
+    logprob_info: Optional[List[RawLogprobsInfo]],
+    delta: str,
     gen_seq: GenerationSequence,
-    prompt_token_ids,
-    tokenizer,
-):
+    prompt_token_ids: List[int],
+    tokenizer: TokenizerP,
+) -> List[Optional[LogprobsContent]]:
     if logprob_info is None:
         return []
 
@@ -219,7 +219,7 @@ def prepare_output(
     logprob_info,
     tokenizer: TokenizerP,
     stopping_criteria: StoppingCriteria,
-) -> str:
+) -> Tuple[str, List[Optional[LogprobsContent]]]:
     gen_seq.next_start_position = len(prompt_token_ids) + len(
         gen_seq.generated_token_ids
     )
@@ -227,7 +227,7 @@
     delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer)
     gen_seq.output_text += delta
 
-    out_logprob_info = prepare_logprob(
+    out_logprob_info: List[Optional[LogprobsContent]] = prepare_logprob(
         logprob_info, delta, gen_seq, prompt_token_ids, tokenizer
     )
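
Reviewer note (not part of the diff): the fully annotated `prepare_logprob` signature above can be exercised as in the minimal sketch below. It assumes the module is importable as `mlc_serve.engine.engine_common` and uses inert stand-ins for the sequence and tokenizer, which the `logprob_info is None` fast path shown in the hunk never touches.

```python
# Minimal sketch, not part of the diff. Only the `logprob_info is None`
# fast path is exercised, so the GenerationSequence and TokenizerP
# arguments can be inert stand-ins here.
from typing import Any, List, Optional

from mlc_serve.engine.engine_common import prepare_logprob  # import path assumed

gen_seq: Any = object()    # stand-in for a real GenerationSequence
tokenizer: Any = object()  # stand-in for a real TokenizerP

out: List[Optional["LogprobsContent"]] = prepare_logprob(
    None,       # logprob_info: Optional[List[RawLogprobsInfo]]
    "",         # delta: str
    gen_seq,    # gen_seq: GenerationSequence
    [],         # prompt_token_ids: List[int]
    tokenizer,  # tokenizer: TokenizerP
)
assert out == []  # a None logprob_info short-circuits to an empty list
```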