wip:mypy
sunggg committed Feb 5, 2024
1 parent d00fab3 commit e105f97
Showing 1 changed file with 12 additions and 12 deletions.
serve/mlc_serve/engine/engine_common.py (24 changes: 12 additions & 12 deletions)
@@ -16,6 +16,7 @@
     RequestState,
     SequenceId,
     StoppingCriteria,
+    RawLogprobsInfo,
 )
 from .model_module import (
     DecodeRequest,
@@ -80,20 +81,19 @@ def detokenize_incrementally(
     prompt_tokens: list[int],
     generation_sequence: GenerationSequence,
     tokenizer: TokenizerP,
-    new_token_id=None,
-    skip_special_tokens=False,
+    new_token_id: Optional[int] = None,
+    skip_special_tokens: bool = False,
 ) -> str:
     # tokenizer.decode() is similar to doing convert_tokens_to_string(convert_ids_to_tokens(token_ids))
     # in this function, we separate these two steps.
     is_logprob = new_token_id is not None
     new_token_id = (
         generation_sequence.generated_token_ids[-1] if not is_logprob else new_token_id
     )
 
     # This is the first iteration for this sequence
     if generation_sequence.prev_tokens is None:
         # TODO(masahi): Figure out a way to remove this concat
-        new_tokens: Union[str, List[str]] = tokenizer.convert_ids_to_tokens(
+        new_tokens: List[str] = tokenizer.convert_ids_to_tokens(  # type: ignore
             prompt_tokens + generation_sequence.generated_token_ids
         )
         output_tokens = new_tokens
@@ -109,7 +109,7 @@ def detokenize_incrementally(
         prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens: Union[str, List[str]] = tokenizer.convert_ids_to_tokens(
+        new_tokens: List[str] = tokenizer.convert_ids_to_tokens(  # type: ignore
             [new_token_id]
         )
         output_tokens = generation_sequence.prev_tokens + new_tokens
@@ -170,12 +170,12 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
 
 
 def prepare_logprob(
-    logprob_info,
-    delta,
+    logprob_info: Optional[List[RawLogprobsInfo]],
+    delta: str,
     gen_seq: GenerationSequence,
-    prompt_token_ids,
-    tokenizer,
-):
+    prompt_token_ids: List[int],
+    tokenizer: TokenizerP,
+) -> List[Optional[LogprobsContent]]:
     if logprob_info is None:
         return []
 
@@ -219,15 +219,15 @@ def prepare_output(
     logprob_info,
     tokenizer: TokenizerP,
     stopping_criteria: StoppingCriteria,
-) -> str:
+) -> Tuple[str, List[Optional[LogprobsContent]]]:
     gen_seq.next_start_position = len(prompt_token_ids) + len(
         gen_seq.generated_token_ids
     )
     gen_seq.generated_token_ids.extend(new_token_ids)
     delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer)
     gen_seq.output_text += delta
 
-    out_logprob_info = prepare_logprob(
+    out_logprob_info: List[Optional[LogprobsContent]] = prepare_logprob(
         logprob_info, delta, gen_seq, prompt_token_ids, tokenizer
     )
 
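A note on the two `# type: ignore` comments: Hugging Face-style tokenizers declare `convert_ids_to_tokens` as returning `Union[str, List[str]]` (a single id maps to a `str`, a list of ids to a `List[str]`), which is the union visible in the removed annotations. Both call sites here always pass a list, so the result is always `List[str]`, but mypy cannot narrow the union on its own; the commit pins the narrower annotation and suppresses the resulting assignment error. Below is a minimal sketch of that pattern, using a hypothetical stand-in for the tokenizer method rather than the project's actual `TokenizerP` implementation:

from typing import List, Union

def convert_ids_to_tokens(ids: Union[int, List[int]]) -> Union[str, List[str]]:
    # Stand-in with the same declared shape as the tokenizer method:
    # a single id yields one token string, a list yields a list of strings.
    if isinstance(ids, int):
        return f"<tok{ids}>"
    return [f"<tok{i}>" for i in ids]

token_ids: List[int] = [1, 2, 3]

# mypy infers Union[str, List[str]] for the call result; pinning the narrower
# List[str] that a list argument always produces triggers an assignment error,
# which the trailing comment suppresses, the same pattern as in the diff above.
new_tokens: List[str] = convert_ids_to_tokens(token_ids)  # type: ignore

Per the new annotations, `prepare_output` now advertises a `Tuple[str, List[Optional[LogprobsContent]]]` result: the detokenized delta paired with the per-token logprob contents from `prepare_logprob`, rather than a bare string.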
