Support Emoji w/o Perf Regression #107

Merged
3 changes: 3 additions & 0 deletions serve/mlc_serve/engine/base.py
@@ -271,6 +271,9 @@ class GenerationSequence:
generated_token_ids: list[int]
next_start_position: int
output_text: str
prefix_begin_offset: int = 0
prefix_end_offset: int = 0
prev_tokens: Optional[List[str]] = None
is_finished: bool = False
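
The three new fields hold the incremental detokenization state added by this PR: `prev_tokens` caches the string pieces already obtained from `convert_ids_to_tokens`, and the half-open window `[prefix_begin_offset, prefix_end_offset)` marks the recent, already-emitted pieces that get re-detokenized each step so subword merging and byte fallback are handled correctly; pieces at `prefix_end_offset` and beyond have not yet contributed to `output_text`. A minimal sketch of how the window advances (plain Python lists, illustrative token strings only, no engine objects):

```python
# Illustrative only: how the offset window slides over prev_tokens as decoding
# proceeds. Byte-fallback pieces (<0x..>) stay pending until they complete a
# character, then the window jumps past them in one step.
prev_tokens = ["▁Hello", "▁world", "<0xF0>", "<0x9F>", "<0x98>", "<0x80>"]

# The last already-emitted pieces form the comparison prefix; everything from
# prefix_end_offset onward has not yet contributed to output_text.
prefix_begin_offset, prefix_end_offset = 0, 2
pending = prev_tokens[prefix_end_offset:]  # four emoji bytes, not yet printable

# Once those bytes decode to a complete emoji, detokenize_incrementally
# (engine_common.py below) emits the delta and advances the window:
prefix_begin_offset, prefix_end_offset = prefix_end_offset, len(prev_tokens)
```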


73 changes: 56 additions & 17 deletions serve/mlc_serve/engine/engine_common.py
@@ -20,7 +20,6 @@
from .model_module import (
DecodeRequest,
PrefillRequest,
Tokenizer,
ConversationTemplate,
KVCacheManager,
ModelModule,
@@ -33,7 +32,7 @@


def get_new_request_state(
request: Request, conversation_template: ConversationTemplate, tokenizer: Tokenizer
request: Request, conversation_template: ConversationTemplate, tokenizer: TokenizerP
) -> RequestState:
if request.debug_options.prompt is not None:
prompt = request.debug_options.prompt
@@ -68,28 +67,68 @@ def get_new_request_state(
)


def decode_last_output(
# Based on vllm: https://github.com/vllm-project/vllm/pull/984
def detokenize_incrementally(
prompt_tokens: list[int],
generation_sequence: GenerationSequence,
tokenizer: Tokenizer,
tokenizer: TokenizerP,
skip_special_tokens=False,
) -> str:
if len(generation_sequence.output_text):
prefix_idx = max(0, generation_sequence.next_start_position - 6)
new_token_id = generation_sequence.generated_token_ids[-1]

# This is the first iteration for this sequence
if generation_sequence.prev_tokens is None:
# TODO(masahi): Figure out a way to remove this concat
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_tokens + generation_sequence.generated_token_ids
)
output_tokens = new_tokens

# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_begin_offset = max(len(output_tokens) - 6, 0)

if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
prefix_end_offset = max(len(output_tokens), 0)
else:
prefix_end_offset = max(len(output_tokens) - 1, 0)
else:
prefix_idx = generation_sequence.next_start_position
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
output_tokens = generation_sequence.prev_tokens + new_tokens

# TODO(masahi): Figure out a way to remove this concat
token_ids = prompt_tokens + generation_sequence.generated_token_ids
prefix_begin_offset = generation_sequence.prefix_begin_offset
prefix_end_offset = generation_sequence.prefix_end_offset

if prefix_idx == 0:
return tokenizer.decode(token_ids)
assert tokenizer.is_fast

prefix = tokenizer.decode(
token_ids[prefix_idx : generation_sequence.next_start_position]
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_begin_offset:prefix_end_offset]
)
full = tokenizer.decode(token_ids[prefix_idx:])
new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_begin_offset:])

if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
# A replacement character ("�") at the end means a potentially unfinished
# byte sequence from byte-fallback tokenization, so hold the delta back.
# If it appears in the middle, it is probably a real invalid id generated
# by the model and is emitted as-is.
new_prefix_begin_offset = prefix_end_offset
new_prefix_end_offset = len(output_tokens)
delta = new_text[len(prefix_text) :]
else:
new_prefix_begin_offset = prefix_begin_offset
new_prefix_end_offset = prefix_end_offset
delta = ""

generation_sequence.prefix_begin_offset = new_prefix_begin_offset
generation_sequence.prefix_end_offset = new_prefix_end_offset
if generation_sequence.prev_tokens is None:
generation_sequence.prev_tokens = new_tokens
else:
generation_sequence.prev_tokens.extend(new_tokens)

return full[len(prefix) :]
return delta
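
The function above follows vLLM's incremental detokenization (linked in the code comment). The reason the token-level prefix diff is needed: byte-fallback tokenizers such as Llama's split an emoji into several single-byte tokens, so decoding each new id on its own yields replacement characters, which is exactly the breakage this PR fixes without re-decoding the whole output. A self-contained sketch of the failure mode and of a simplified version of the fix (assumptions: `transformers` is installed, the `hf-internal-testing/llama-tokenizer` checkpoint is reachable, and full `decode` calls stand in for the token-window diff used above):

```python
from transformers import AutoTokenizer  # assumption: transformers is installed

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
ids = tok.encode("Hello 👋", add_special_tokens=False)

# Naive per-token decoding: the emoji's byte-fallback pieces each render as "�".
print("".join(tok.decode([i]) for i in ids))

# Simplified incremental decoding: re-decode the running sequence and only emit
# a delta once the text no longer ends in an unfinished "�" byte sequence.
emitted = ""
for n in range(1, len(ids) + 1):
    text = tok.decode(ids[:n])
    if not text.endswith("�") and len(text) > len(emitted):
        print(repr(text[len(emitted):]))  # clean deltas, emoji emitted whole
        emitted = text
```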


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
@@ -115,14 +154,14 @@ def update_sequence(
gen_seq: GenerationSequence,
new_token_ids: list[int],
prompt_token_ids: list[int],
tokenizer: Tokenizer,
tokenizer: TokenizerP,
stopping_criteria: StoppingCriteria,
) -> str:
gen_seq.next_start_position = len(prompt_token_ids) + len(
gen_seq.generated_token_ids
)
gen_seq.generated_token_ids.extend(new_token_ids)
delta = decode_last_output(prompt_token_ids, gen_seq, tokenizer)
delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer)
gen_seq.output_text += delta

gen_seq.output_text, delta, gen_seq.is_finished = check_stopping_sequences(
32 changes: 20 additions & 12 deletions serve/mlc_serve/engine/model_module.py
@@ -2,7 +2,7 @@
Required interfaces for the actual inference capability in InferenceEngine.
"""
from dataclasses import dataclass
from typing import Optional, Protocol, Union
from typing import Optional, Protocol, Union, List

from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
from ..model.base import ModelArtifactConfig
@@ -12,7 +12,7 @@
@dataclass
class PrefillRequest:
request_id: RequestId
token_ids: list[int]
token_ids: List[int]
# Number of sequences to generate
num_sequence: int
sampling_params: SamplingParams
@@ -28,7 +28,7 @@ class PrefillRequest:
class DecodeRequest:
sequence_id: SequenceId
# All tokens for this request, including prompt
token_ids: list[int]
token_ids: List[int]
sampling_params: SamplingParams


@@ -41,7 +41,7 @@ class TextGenerationResult:
sequence_id: SequenceId
# for most cases, there should be only one token returned
# making this a list of token ids to leave room for speculative decoding
generated_tokens: list[int]
generated_tokens: List[int]
error: Optional[str]


@@ -116,9 +116,9 @@ class TextGenerator(Protocol):

def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest]],
requests: List[Union[PrefillRequest, DecodeRequest]],
kv_cache: KVCache,
) -> list[TextGenerationResult]:
) -> List[TextGenerationResult]:
"""
A unified entrypoint for text generation.

@@ -130,17 +130,25 @@ def generate(

class Tokenizer(Protocol):
eos_token_id: int
skip_special_tokens: bool
all_special_ids: List[int]
is_fast: bool

def encode(self, text: str) -> list[int]:
pass
def encode(self, text: str) -> List[int]:
...

# TODO: Incremental decoding
def decode(self, tokens: list[int]) -> str:
pass
def decode(self, tokens: List[int]) -> str:
...

def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
...

def convert_tokens_to_string(self, tokens: List[str]) -> str:
...
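
Because `Tokenizer` here is a `typing.Protocol`, conformance is structural: any object exposing these attributes and methods can be passed wherever the engine expects a tokenizer, with no inheritance from this module. A sketch under that assumption (the `WhitespaceTokenizer` below is purely illustrative and not part of the PR; the import path is inferred from the file location):

```python
from typing import List

# Hypothetical import path, based on serve/mlc_serve/engine/model_module.py.
from mlc_serve.engine.model_module import Tokenizer


class WhitespaceTokenizer:
    """Toy tokenizer, not part of the PR: it only shows that any object with
    the right attributes and methods satisfies the protocol."""

    eos_token_id: int = 0
    skip_special_tokens: bool = True
    all_special_ids: List[int] = [0]
    is_fast: bool = True

    def encode(self, text: str) -> List[int]:
        return [len(word) for word in text.split()]

    def decode(self, token_ids: List[int]) -> str:
        return " ".join(str(i) for i in token_ids)

    def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
        return [str(i) for i in token_ids]

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return " ".join(tokens)


def count_eos(tokenizer: Tokenizer, token_ids: List[int]) -> int:
    # Accepts anything matching the protocol, including the toy class above.
    return token_ids.count(tokenizer.eos_token_id)


count_eos(WhitespaceTokenizer(), [0, 5, 0])  # -> 2
```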


class ConversationTemplate(Protocol):
def apply(self, messages: list[ChatMessage]) -> str:
def apply(self, messages: List[ChatMessage]) -> str:
pass


19 changes: 16 additions & 3 deletions serve/mlc_serve/model/tokenizer.py
@@ -5,15 +5,28 @@


class Tokenizer:
def __init__(self, hf_tokenizer):
def __init__(self, hf_tokenizer, skip_special_tokens=True):
self._tokenizer = hf_tokenizer
self.eos_token_id = self._tokenizer.eos_token_id
self.skip_special_tokens = skip_special_tokens
self.all_special_ids = self._tokenizer.all_special_ids
self.is_fast = self._tokenizer.is_fast

def encode(self, text: str) -> List[int]:
return self._tokenizer.encode(text)

def decode(self, tokens: List[int]) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=True)
def decode(self, token_ids: List[int]) -> str:
return self._tokenizer.decode(
token_ids, skip_special_tokens=self.skip_special_tokens
)

def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
return self._tokenizer.convert_ids_to_tokens(
token_ids, skip_special_tokens=self.skip_special_tokens
)

def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self._tokenizer.convert_tokens_to_string(tokens)
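
A short usage sketch of the wrapper (assuming `transformers` is installed; the checkpoint name is only an example and the import path is inferred from the file layout), showing the id → token-piece → string path that the incremental detokenizer relies on:

```python
from transformers import AutoTokenizer

# Hypothetical import path, based on serve/mlc_serve/model/tokenizer.py.
from mlc_serve.model.tokenizer import Tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer = Tokenizer(hf_tokenizer, skip_special_tokens=True)

ids = tokenizer.encode("Hello 👋")
pieces = tokenizer.convert_ids_to_tokens(ids)      # e.g. ['▁Hello', '▁', '<0xF0>', ...]
text = tokenizer.convert_tokens_to_string(pieces)  # reassembles "Hello 👋"
assert tokenizer.is_fast  # detokenize_incrementally asserts a fast tokenizer
```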


class ConversationTemplate: