diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index f8b9dce3c1..ee89cd9749 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -272,7 +272,7 @@ class GenerationSequence:
     next_start_position: int
     output_text: str
     prefix_begin_offset: int = 0
-    new_prefix_end_offset: int = 0
+    prefix_end_offset: int = 0
     prev_tokens: Optional[List[str]] = None
     is_finished: bool = False
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 09b7410fc3..92c1155305 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -20,7 +20,6 @@ from .model_module import (
     DecodeRequest,
     PrefillRequest,
-    Tokenizer,
     ConversationTemplate,
     KVCacheManager,
     ModelModule,
@@ -33,7 +32,7 @@ def get_new_request_state(
-    request: Request, conversation_template: ConversationTemplate, tokenizer: Tokenizer
+    request: Request, conversation_template: ConversationTemplate, tokenizer: TokenizerP
 ) -> RequestState:
     if request.debug_options.prompt is not None:
         prompt = request.debug_options.prompt
@@ -72,7 +71,7 @@ def detokenize_incrementally(
     prompt_tokens: list[int],
     generation_sequence: GenerationSequence,
-    tokenizer: Tokenizer,
+    tokenizer: TokenizerP,
     skip_special_tokens=False,
 ) -> str:
     new_token_id = generation_sequence.generated_token_ids[-1]
@@ -155,7 +154,7 @@ def update_sequence(
     gen_seq: GenerationSequence,
     new_token_ids: list[int],
     prompt_token_ids: list[int],
-    tokenizer: Tokenizer,
+    tokenizer: TokenizerP,
     stopping_criteria: StoppingCriteria,
 ) -> str:
     gen_seq.next_start_position = len(prompt_token_ids) + len(
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index b546f6cf7b..8c8182b191 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -2,7 +2,7 @@
 Required interfaces for the actual inference capability in InferenceEngine.
 """
 from dataclasses import dataclass
-from typing import Optional, Protocol, Union
+from typing import Optional, Protocol, Union, List
 from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
 from ..model.base import ModelArtifactConfig
@@ -12,7 +12,7 @@
 @dataclass
 class PrefillRequest:
     request_id: RequestId
-    token_ids: list[int]
+    token_ids: List[int]
     # Number of sequences to generate
     num_sequence: int
     sampling_params: SamplingParams
@@ -28,7 +28,7 @@ class PrefillRequest:
 class DecodeRequest:
     sequence_id: SequenceId
     # All tokens for this request, including prompt
-    token_ids: list[int]
+    token_ids: List[int]
     sampling_params: SamplingParams
@@ -41,7 +41,7 @@ class TextGenerationResult:
     sequence_id: SequenceId
     # for most cases, there should be only one token returned
     # making this a list of token ids to leave room for speculative decoding
-    generated_tokens: list[int]
+    generated_tokens: List[int]
     error: Optional[str]
@@ -116,9 +116,9 @@ class TextGenerator(Protocol):
     def generate(
         self,
-        requests: list[Union[PrefillRequest, DecodeRequest]],
+        requests: List[Union[PrefillRequest, DecodeRequest]],
         kv_cache: KVCache,
-    ) -> list[TextGenerationResult]:
+    ) -> List[TextGenerationResult]:
         """
         A unified entrypoint for text generation.
@@ -130,16 +130,25 @@ def generate(
 class Tokenizer(Protocol):
     eos_token_id: int
+    skip_special_tokens: bool
+    all_special_ids: List[int]
+    is_fast: bool
 
-    def encode(self, text: str) -> list[int]:
-        pass
+    def encode(self, text: str) -> List[int]:
+        ...
 
-    def decode(self, tokens: list[int]) -> str:
-        pass
+    def decode(self, tokens: List[int]) -> str:
+        ...
+
+    def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
+        ...
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        ...
 
 
 class ConversationTemplate(Protocol):
-    def apply(self, messages: list[ChatMessage]) -> str:
+    def apply(self, messages: List[ChatMessage]) -> str:
         pass
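
For illustration only (not part of this patch), here is a minimal sketch of a class that structurally satisfies the expanded Tokenizer protocol. Because Tokenizer is a typing.Protocol, conformance is structural and no inheritance is needed; the added members (is_fast, all_special_ids, convert_ids_to_tokens, convert_tokens_to_string) mirror the Hugging Face tokenizer surface. The WhitespaceTokenizer name and its vocabulary handling are hypothetical assumptions made for this sketch.

from typing import Dict, List

# Hypothetical toy tokenizer: splits on whitespace and assigns ids on the fly.
# It exists only to show the attribute/method shape the protocol now requires.
class WhitespaceTokenizer:
    eos_token_id: int = 0
    skip_special_tokens: bool = True
    all_special_ids: List[int] = [0]
    is_fast: bool = False

    def __init__(self) -> None:
        # Token string <-> id maps; id 0 is reserved for the EOS token.
        self._vocab: Dict[str, int] = {"</s>": 0}
        self._inv_vocab: Dict[int, str] = {0: "</s>"}

    def encode(self, text: str) -> List[int]:
        ids = []
        for tok in text.split():
            if tok not in self._vocab:
                new_id = len(self._vocab)
                self._vocab[tok] = new_id
                self._inv_vocab[new_id] = tok
            ids.append(self._vocab[tok])
        return ids

    def decode(self, tokens: List[int]) -> str:
        return self.convert_tokens_to_string(self.convert_ids_to_tokens(tokens))

    def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
        return [self._inv_vocab[i] for i in token_ids]

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return " ".join(tokens)


if __name__ == "__main__":
    tok = WhitespaceTokenizer()
    ids = tok.encode("hello world")
    assert tok.decode(ids) == "hello world"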