diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index 5c9373c282..f8b9dce3c1 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -271,8 +271,8 @@ class GenerationSequence:
     generated_token_ids: list[int]
     next_start_position: int
     output_text: str
-    prefix_offset: int = 0
-    read_offset: int = 0
+    prefix_begin_offset: int = 0
+    prefix_end_offset: int = 0
     prev_tokens: Optional[List[str]] = None
     is_finished: bool = False
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index fe9cf6eebe..09b7410fc3 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -68,66 +68,62 @@ def get_new_request_state(
     )


-def decode_last_output(
+# Based on vllm: https://github.com/vllm-project/vllm/pull/984
+def detokenize_incrementally(
     prompt_tokens: list[int],
     generation_sequence: GenerationSequence,
     tokenizer: Tokenizer,
-    skip_special_tokens=True,
+    skip_special_tokens=False,
 ) -> str:
     new_token_id = generation_sequence.generated_token_ids[-1]

     # This is the first iteration for this sequence
     if generation_sequence.prev_tokens is None:
         # TODO(masahi): Figure out a way to remove this concat
-        all_input_ids = prompt_tokens + generation_sequence.generated_token_ids
-        new_tokens = tokenizer._tokenizer.convert_ids_to_tokens(
-            all_input_ids, skip_special_tokens=skip_special_tokens
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            prompt_tokens + generation_sequence.generated_token_ids
         )
         output_tokens = new_tokens

         # 5 is an arbitrary value that should work for all
         # tokenizers (bigger = more conservative).
         # Subtract 1 extra to account for the generated token.
-        prefix_offset = max(len(output_tokens) - 6, 0)
+        prefix_begin_offset = max(len(output_tokens) - 6, 0)

-        if skip_special_tokens and new_token_id in tokenizer._tokenizer.all_special_ids:
-            read_offset = max(len(output_tokens), 0)
+        if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
+            prefix_end_offset = max(len(output_tokens), 0)
         else:
-            read_offset = max(len(output_tokens) - 1, 0)
+            prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens = tokenizer._tokenizer.convert_ids_to_tokens(
-            [new_token_id], skip_special_tokens=skip_special_tokens
-        )
+        new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
         output_tokens = generation_sequence.prev_tokens + new_tokens

-        prefix_offset = generation_sequence.prefix_offset
-        read_offset = generation_sequence.read_offset
+        prefix_begin_offset = generation_sequence.prefix_begin_offset
+        prefix_end_offset = generation_sequence.prefix_end_offset

-    assert tokenizer._tokenizer.is_fast
+    assert tokenizer.is_fast

-    prefix_text = tokenizer._tokenizer.convert_tokens_to_string(
-        output_tokens[prefix_offset:read_offset]
-    )
-    new_text = tokenizer._tokenizer.convert_tokens_to_string(
-        output_tokens[prefix_offset:]
+    prefix_text = tokenizer.convert_tokens_to_string(
+        output_tokens[prefix_begin_offset:prefix_end_offset]
     )
+    new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_begin_offset:])

     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence
         # from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
-        new_prefix_offset = read_offset
-        new_read_offset = len(output_tokens)
+        new_prefix_begin_offset = prefix_end_offset
+        new_prefix_end_offset = len(output_tokens)
         delta = new_text[len(prefix_text) :]
     else:
-        new_prefix_offset = prefix_offset
-        new_read_offset = read_offset
+        new_prefix_begin_offset = prefix_begin_offset
+        new_prefix_end_offset = prefix_end_offset
         delta = ""

-    generation_sequence.prefix_offset = new_prefix_offset
-    generation_sequence.read_offset = new_read_offset
+    generation_sequence.prefix_begin_offset = new_prefix_begin_offset
+    generation_sequence.prefix_end_offset = new_prefix_end_offset
     if generation_sequence.prev_tokens is None:
         generation_sequence.prev_tokens = new_tokens
     else:
@@ -136,41 +132,6 @@ def decode_last_output(
     return delta


-"""
-def decode_last_output(
-    prompt_tokens: list[int],
-    generation_sequence: GenerationSequence,
-    tokenizer: Tokenizer,
-) -> str:
-    if len(generation_sequence.output_text):
-        LOG.info(
-            f"T:{generation_sequence.output_text}({len({generation_sequence.output_text})})"
-        )
-        prefix_idx = max(0, generation_sequence.next_start_position - 6)
-    else:
-        LOG.info(
-            f"N: {generation_sequence.output_text}({len({generation_sequence.output_text})})"
-        )
-        prefix_idx = generation_sequence.next_start_position
-
-    # TODO(masahi): Figure out a way to remove this concat
-    token_ids = prompt_tokens + generation_sequence.generated_token_ids
-    LOG.info(
-        f"{generation_sequence.next_start_position}, {len(token_ids)}, {prefix_idx}"
-    )
-
-    if prefix_idx == 0:
-        return tokenizer.decode(token_ids)
-
-    prefix = tokenizer.decode(
-        token_ids[prefix_idx : generation_sequence.next_start_position]
-    )
-    full = tokenizer.decode(token_ids[prefix_idx:])
-
-    return full[len(prefix) :]
-"""
-
-
 def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
     if stopping_criteria.stop_sequences:
         for t in stopping_criteria.stop_sequences:
@@ -201,7 +162,7 @@ def update_sequence(
         gen_seq.generated_token_ids
     )
     gen_seq.generated_token_ids.extend(new_token_ids)
-    delta = decode_last_output(prompt_token_ids, gen_seq, tokenizer)
+    delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer)
     gen_seq.output_text += delta

     gen_seq.output_text, delta, gen_seq.is_finished = check_stopping_sequences(
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index df26e20403..b546f6cf7b 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -134,7 +134,6 @@ class Tokenizer(Protocol):
     def encode(self, text: str) -> list[int]:
         pass

-    # TODO: Incremental decoding
     def decode(self, tokens: list[int]) -> str:
         pass

diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py
index f7757d889d..7802d9f284 100644
--- a/serve/mlc_serve/model/tokenizer.py
+++ b/serve/mlc_serve/model/tokenizer.py
@@ -5,15 +5,28 @@


 class Tokenizer:
-    def __init__(self, hf_tokenizer):
+    def __init__(self, hf_tokenizer, skip_special_tokens=True):
         self._tokenizer = hf_tokenizer
         self.eos_token_id = self._tokenizer.eos_token_id
+        self.skip_special_tokens = skip_special_tokens
+        self.all_special_ids = self._tokenizer.all_special_ids
+        self.is_fast = self._tokenizer.is_fast

     def encode(self, text: str) -> List[int]:
         return self._tokenizer.encode(text)

-    def decode(self, tokens: List[int]) -> str:
-        return self._tokenizer.decode(tokens, skip_special_tokens=True)
+    def decode(self, token_ids: List[int]) -> str:
+        return self._tokenizer.decode(
+            token_ids, skip_special_tokens=self.skip_special_tokens
+        )
+
+    def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
+        return self._tokenizer.convert_ids_to_tokens(
+            token_ids, skip_special_tokens=self.skip_special_tokens
+        )
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self._tokenizer.convert_tokens_to_string(tokens)


 class ConversationTemplate:
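
Below is a minimal usage sketch (not part of the diff) of how the incremental detokenization above is driven, mirroring the call pattern in update_sequence: append the newly sampled token id to the sequence, call detokenize_incrementally, and concatenate the returned delta onto output_text. The tokenizer name, the stand-in "sampled" ids, the import paths (which assume serve/ is on PYTHONPATH), and any GenerationSequence constructor fields not visible in the hunks above (e.g. a sequence id) are assumptions for illustration only.

    from transformers import AutoTokenizer

    from mlc_serve.engine.base import GenerationSequence
    from mlc_serve.engine.engine_common import detokenize_incrementally
    from mlc_serve.model.tokenizer import Tokenizer

    # Any fast (Rust-backed) HF tokenizer works; detokenize_incrementally asserts is_fast.
    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
    tokenizer = Tokenizer(hf_tokenizer)  # wrapper from serve/mlc_serve/model/tokenizer.py

    prompt_tokens = tokenizer.encode("The quick brown fox")
    gen_seq = GenerationSequence(
        generated_token_ids=[],
        next_start_position=len(prompt_tokens),
        output_text="",
        # Fields required by the real dataclass but not shown in the hunk above
        # (e.g. a sequence id) would also need to be supplied here.
    )

    # Stand-in for ids sampled from the model, fed in one per decode step.
    for new_token_id in tokenizer.encode(" jumps over the lazy dog"):
        gen_seq.generated_token_ids.append(new_token_id)
        delta = detokenize_incrementally(prompt_tokens, gen_seq, tokenizer)
        gen_seq.output_text += delta  # deltas concatenate into the running text

    print(gen_seq.output_text)

The prefix_begin_offset/prefix_end_offset bookkeeping means each call re-decodes only the last few tokens, and a delta is held back (returned as "") while the decoded text still ends in "�", so partial byte-fallback UTF-8 sequences are never emitted to the client.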