mypy
sunggg committed Dec 12, 2023
1 parent 26ca476 commit 845dfd5
Showing 3 changed files with 24 additions and 16 deletions.
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/base.py

@@ -272,7 +272,7 @@ class GenerationSequence:
     next_start_position: int
     output_text: str
     prefix_begin_offset: int = 0
-    new_prefix_end_offset: int = 0
+    prefix_end_offset: int = 0
     prev_tokens: Optional[List[str]] = None
     is_finished: bool = False
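For context, these fields are the bookkeeping for incremental detokenization: instead of re-decoding the whole sequence after every generated token, the engine keeps the decoded token strings (prev_tokens) plus a window into them (prefix_begin_offset / prefix_end_offset) marking what has already been rendered as text. Below is a minimal, self-contained sketch of the general technique; the toy vocabulary and the simplified window-update rule are illustrative only, and the repository's actual logic lives in detokenize_incrementally in engine_common.py.

    from typing import List, Tuple

    # Toy vocabulary so the sketch runs without a real tokenizer.
    VOCAB = ["<s>", "Hello", ",", " world", "!"]


    def convert_ids_to_tokens(ids: List[int]) -> List[str]:
        return [VOCAB[i] for i in ids]


    def convert_tokens_to_string(tokens: List[str]) -> str:
        return "".join(tokens)


    def detokenize_step(
        token_ids: List[int],      # full sequence: prompt + generated
        prev_tokens: List[str],    # token strings decoded so far
        prefix_begin: int,         # window into prev_tokens that was
        prefix_end: int,           #   last rendered as text
    ) -> Tuple[str, List[str], int, int]:
        """Return the text delta produced by the newest token(s)."""
        new_tokens = convert_ids_to_tokens(token_ids[len(prev_tokens):])
        prev_tokens = prev_tokens + new_tokens

        prefix_text = convert_tokens_to_string(prev_tokens[prefix_begin:prefix_end])
        full_text = convert_tokens_to_string(prev_tokens[prefix_begin:])

        # The delta is whatever extends the previously rendered prefix.
        # (Real engines also hold back incomplete UTF-8 byte pieces here.)
        delta = full_text[len(prefix_text):]
        return delta, prev_tokens, prefix_end, len(prev_tokens)


    # Decode token by token, as a serving loop would.
    prompt = [0]
    prev = convert_ids_to_tokens(prompt)
    begin, end = 0, len(prev)
    ids, text = list(prompt), ""
    for next_id in [1, 2, 3, 4]:
        ids.append(next_id)
        delta, prev, begin, end = detokenize_step(ids, prev, begin, end)
        text += delta
    assert text == "Hello, world!"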
7 changes: 3 additions & 4 deletions serve/mlc_serve/engine/engine_common.py

@@ -20,7 +20,6 @@
 from .model_module import (
     DecodeRequest,
     PrefillRequest,
-    Tokenizer,
     ConversationTemplate,
     KVCacheManager,
     ModelModule,

@@ -33,7 +32,7 @@


 def get_new_request_state(
-    request: Request, conversation_template: ConversationTemplate, tokenizer: Tokenizer
+    request: Request, conversation_template: ConversationTemplate, tokenizer: TokenizerP
 ) -> RequestState:
     if request.debug_options.prompt is not None:
         prompt = request.debug_options.prompt

@@ -72,7 +71,7 @@ def get_new_request_state(
 def detokenize_incrementally(
     prompt_tokens: list[int],
     generation_sequence: GenerationSequence,
-    tokenizer: Tokenizer,
+    tokenizer: TokenizerP,
     skip_special_tokens=False,
 ) -> str:
     new_token_id = generation_sequence.generated_token_ids[-1]

@@ -155,7 +154,7 @@ def update_sequence(
     gen_seq: GenerationSequence,
     new_token_ids: list[int],
     prompt_token_ids: list[int],
-    tokenizer: Tokenizer,
+    tokenizer: TokenizerP,
     stopping_criteria: StoppingCriteria,
 ) -> str:
     gen_seq.next_start_position = len(prompt_token_ids) + len(
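TokenizerP itself is not defined in this diff (note that the concrete Tokenizer import is dropped above); presumably it is a Protocol describing the tokenizer interface. The point of annotating against a protocol, sketched below with assumed names, is that mypy then checks callers structurally: any tokenizer-shaped object is accepted, with no common base class required.

    from typing import List, Protocol


    class TokenizerP(Protocol):
        """Assumed shape; the real definition is outside this diff."""

        eos_token_id: int

        def encode(self, text: str) -> List[int]:
            ...

        def decode(self, tokens: List[int]) -> str:
            ...


    def count_prompt_tokens(prompt: str, tokenizer: TokenizerP) -> int:
        # mypy accepts any argument whose attributes and method signatures
        # match TokenizerP -- no subclassing of a concrete Tokenizer needed.
        return len(tokenizer.encode(prompt))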
31 changes: 20 additions & 11 deletions serve/mlc_serve/engine/model_module.py

@@ -2,7 +2,7 @@
 Required interfaces for the actual inference capability in InferenceEngine.
 """
 from dataclasses import dataclass
-from typing import Optional, Protocol, Union
+from typing import Optional, Protocol, Union, List

 from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId
 from ..model.base import ModelArtifactConfig

@@ -12,7 +12,7 @@
 @dataclass
 class PrefillRequest:
     request_id: RequestId
-    token_ids: list[int]
+    token_ids: List[int]
     # Number of sequences to generate
     num_sequence: int
     sampling_params: SamplingParams

@@ -28,7 +28,7 @@ class PrefillRequest:
 class DecodeRequest:
     sequence_id: SequenceId
     # All tokens for this request, including prompt
-    token_ids: list[int]
+    token_ids: List[int]
     sampling_params: SamplingParams


@@ -41,7 +41,7 @@ class TextGenerationResult:
     sequence_id: SequenceId
     # for most cases, there should be only one token returned
     # making this a list of token ids to leave room for speculative decoding
-    generated_tokens: list[int]
+    generated_tokens: List[int]
     error: Optional[str]


@@ -116,9 +116,9 @@ class TextGenerator(Protocol):
     def generate(
         self,
-        requests: list[Union[PrefillRequest, DecodeRequest]],
+        requests: List[Union[PrefillRequest, DecodeRequest]],
         kv_cache: KVCache,
-    ) -> list[TextGenerationResult]:
+    ) -> List[TextGenerationResult]:
         """
         A unified entrypoint for text generation.

@@ -130,16 +130,25 @@ def generate(
 class Tokenizer(Protocol):
     eos_token_id: int
+    skip_special_tokens: bool
+    all_special_ids: List[int]
+    is_fast: bool

-    def encode(self, text: str) -> list[int]:
-        pass
+    def encode(self, text: str) -> List[int]:
+        ...

-    def decode(self, tokens: list[int]) -> str:
-        pass
+    def decode(self, tokens: List[int]) -> str:
+        ...
+
+    def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
+        ...
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        ...


 class ConversationTemplate(Protocol):
-    def apply(self, messages: list[ChatMessage]) -> str:
+    def apply(self, messages: List[ChatMessage]) -> str:
         pass
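Two things are going on in this file. Switching annotations from list[...] to typing.List is likely about the mypy target: subscripted built-in generics are only valid from Python 3.9, so typing.List keeps the same annotations checkable against older interpreter versions. And the expanded Tokenizer Protocol makes the duck typing explicit: a class conforms by shape alone. A minimal conforming implementation is sketched below; the ToyTokenizer class and its vocabulary are hypothetical, and only the protocol members shown in this diff are assumed.

    from typing import List, Protocol


    class Tokenizer(Protocol):
        eos_token_id: int
        skip_special_tokens: bool
        all_special_ids: List[int]
        is_fast: bool

        def encode(self, text: str) -> List[int]: ...
        def decode(self, tokens: List[int]) -> str: ...
        def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]: ...
        def convert_tokens_to_string(self, tokens: List[str]) -> str: ...


    class ToyTokenizer:
        """Hypothetical whitespace tokenizer. It never names Tokenizer,
        yet mypy accepts it wherever a Tokenizer is expected because its
        attributes and method signatures match the protocol."""

        eos_token_id = 0
        skip_special_tokens = False
        all_special_ids = [0]
        is_fast = False

        _vocab = ["</s>", "hello", "world"]

        def encode(self, text: str) -> List[int]:
            return [self._vocab.index(t) for t in text.split()]

        def decode(self, tokens: List[int]) -> str:
            return self.convert_tokens_to_string(self.convert_ids_to_tokens(tokens))

        def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
            return [self._vocab[i] for i in token_ids]

        def convert_tokens_to_string(self, tokens: List[str]) -> str:
            return " ".join(tokens)


    def eos_text(tokenizer: Tokenizer) -> str:
        # Structural typing: ToyTokenizer passes without subclassing.
        return tokenizer.decode([tokenizer.eos_token_id])


    assert eos_text(ToyTokenizer()) == "</s>"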
