Commit

done

sunggg committed Dec 12, 2023
1 parent f34811d commit 26ca476
Showing 4 changed files with 41 additions and 68 deletions.
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -271,8 +271,8 @@ class GenerationSequence:
generated_token_ids: list[int]
next_start_position: int
output_text: str
prefix_offset: int = 0
read_offset: int = 0
prefix_begin_offset: int = 0
prefix_end_offset: int = 0
prev_tokens: Optional[List[str]] = None
is_finished: bool = False

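The two renamed fields mark the token window whose surface text is not yet final: prefix_begin_offset is where re-detokenization starts, and prefix_end_offset is where previously emitted text ends. A minimal, self-contained sketch of how a streamed delta falls out of the two offsets; the tokenizer name and input string are arbitrary illustrations, not part of this commit:

from transformers import AutoTokenizer

# Arbitrary fast tokenizer for illustration; not the model served here.
tok = AutoTokenizer.from_pretrained("gpt2")

tokens = tok.convert_ids_to_tokens(tok.encode("Hello, incremental world"))
prefix_begin_offset = 0
prefix_end_offset = len(tokens) - 1  # everything except the newest token was already emitted

# Text already emitted for this window, and text including the newest token.
prefix_text = tok.convert_tokens_to_string(tokens[prefix_begin_offset:prefix_end_offset])
new_text = tok.convert_tokens_to_string(tokens[prefix_begin_offset:])

# The streamed delta is whatever extends past the previously emitted prefix.
delta = new_text[len(prefix_text):]
print(delta)  # prints the newly finalized text, e.g. " world"
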
85 changes: 23 additions & 62 deletions serve/mlc_serve/engine/engine_common.py
@@ -68,66 +68,62 @@ def get_new_request_state(
)


def decode_last_output(
# Based on vllm: https://github.com/vllm-project/vllm/pull/984
def detokenize_incrementally(
prompt_tokens: list[int],
generation_sequence: GenerationSequence,
tokenizer: Tokenizer,
skip_special_tokens=True,
skip_special_tokens=False,
) -> str:
new_token_id = generation_sequence.generated_token_ids[-1]

# This is the first iteration for this sequence
if generation_sequence.prev_tokens is None:
# TODO(masahi): Figure out a way to remove this concat
all_input_ids = prompt_tokens + generation_sequence.generated_token_ids
new_tokens = tokenizer._tokenizer.convert_ids_to_tokens(
all_input_ids, skip_special_tokens=skip_special_tokens
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_tokens + generation_sequence.generated_token_ids
)
output_tokens = new_tokens

# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0)
prefix_begin_offset = max(len(output_tokens) - 6, 0)

if skip_special_tokens and new_token_id in tokenizer._tokenizer.all_special_ids:
read_offset = max(len(output_tokens), 0)
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
prefix_end_offset = max(len(output_tokens), 0)
else:
read_offset = max(len(output_tokens) - 1, 0)
prefix_end_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer._tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens
)
new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
output_tokens = generation_sequence.prev_tokens + new_tokens

prefix_offset = generation_sequence.prefix_offset
read_offset = generation_sequence.read_offset
prefix_begin_offset = generation_sequence.prefix_begin_offset
prefix_end_offset = generation_sequence.prefix_end_offset

assert tokenizer._tokenizer.is_fast
assert tokenizer.is_fast

prefix_text = tokenizer._tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset]
)
new_text = tokenizer._tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:]
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_begin_offset:prefix_end_offset]
)
new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_begin_offset:])

if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_prefix_offset = read_offset
new_read_offset = len(output_tokens)
new_prefix_begin_offset = prefix_end_offset
new_prefix_end_offset = len(output_tokens)
delta = new_text[len(prefix_text) :]
else:
new_prefix_offset = prefix_offset
new_read_offset = read_offset
new_prefix_begin_offset = prefix_begin_offset
new_prefix_end_offset = prefix_end_offset
delta = ""

generation_sequence.prefix_offset = new_prefix_offset
generation_sequence.read_offset = new_read_offset
generation_sequence.prefix_begin_offset = new_prefix_begin_offset
generation_sequence.prefix_end_offset = new_prefix_end_offset
if generation_sequence.prev_tokens is None:
generation_sequence.prev_tokens = new_tokens
else:
@@ -136,41 +132,6 @@ def decode_last_output(
return delta


"""
def decode_last_output(
prompt_tokens: list[int],
generation_sequence: GenerationSequence,
tokenizer: Tokenizer,
) -> str:
if len(generation_sequence.output_text):
LOG.info(
f"T:{generation_sequence.output_text}({len({generation_sequence.output_text})})"
)
prefix_idx = max(0, generation_sequence.next_start_position - 6)
else:
LOG.info(
f"N: {generation_sequence.output_text}({len({generation_sequence.output_text})})"
)
prefix_idx = generation_sequence.next_start_position
# TODO(masahi): Figure out a way to remove this concat
token_ids = prompt_tokens + generation_sequence.generated_token_ids
LOG.info(
f"{generation_sequence.next_start_position}, {len(token_ids)}, {prefix_idx}"
)
if prefix_idx == 0:
return tokenizer.decode(token_ids)
prefix = tokenizer.decode(
token_ids[prefix_idx : generation_sequence.next_start_position]
)
full = tokenizer.decode(token_ids[prefix_idx:])
return full[len(prefix) :]
"""


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
@@ -201,7 +162,7 @@ def update_sequence(
gen_seq.generated_token_ids
)
gen_seq.generated_token_ids.extend(new_token_ids)
delta = decode_last_output(prompt_token_ids, gen_seq, tokenizer)
delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer)
gen_seq.output_text += delta

gen_seq.output_text, delta, gen_seq.is_finished = check_stopping_sequences(
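For context, a hedged sketch of how detokenize_incrementally might be driven from a streaming loop. The names gen_seq, tokenizer, and sampled_token_ids are assumed to exist and are not defined in this diff:

# Illustrative sketch only: gen_seq, tokenizer, and sampled_token_ids are assumed.
prompt_token_ids = tokenizer.encode("The quick brown fox")

for token_id in sampled_token_ids:  # ids produced by the model, one per decoding step
    gen_seq.generated_token_ids.append(token_id)
    # Returns only the newly finalized text, or "" while a trailing "\ufffd"
    # suggests an incomplete byte sequence.
    delta = detokenize_incrementally(prompt_token_ids, gen_seq, tokenizer)
    gen_seq.output_text += delta  # mirrors update_sequence above
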
1 change: 0 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -134,7 +134,6 @@ class Tokenizer(Protocol):
def encode(self, text: str) -> list[int]:
pass

# TODO: Incremental decoding
def decode(self, tokens: list[int]) -> str:
pass

19 changes: 16 additions & 3 deletions serve/mlc_serve/model/tokenizer.py
@@ -5,15 +5,28 @@


class Tokenizer:
def __init__(self, hf_tokenizer):
def __init__(self, hf_tokenizer, skip_special_tokens=True):
self._tokenizer = hf_tokenizer
self.eos_token_id = self._tokenizer.eos_token_id
self.skip_special_tokens = skip_special_tokens
self.all_special_ids = self._tokenizer.all_special_ids
self.is_fast = self._tokenizer.is_fast

def encode(self, text: str) -> List[int]:
return self._tokenizer.encode(text)

def decode(self, tokens: List[int]) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=True)
def decode(self, token_ids: List[int]) -> str:
return self._tokenizer.decode(
token_ids, skip_special_tokens=self.skip_special_tokens
)

def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]:
return self._tokenizer.convert_ids_to_tokens(
token_ids, skip_special_tokens=self.skip_special_tokens
)

def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self._tokenizer.convert_tokens_to_string(tokens)


class ConversationTemplate:
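A minimal usage sketch of the extended wrapper, assuming the import path follows the file location serve/mlc_serve/model/tokenizer.py and using an arbitrary Hugging Face tokenizer:

from transformers import AutoTokenizer

from mlc_serve.model.tokenizer import Tokenizer  # assumed import path

# "gpt2" is an arbitrary fast tokenizer; any fast tokenizer satisfies the
# is_fast assertion in detokenize_incrementally.
hf_tok = AutoTokenizer.from_pretrained("gpt2")
tokenizer = Tokenizer(hf_tok, skip_special_tokens=True)

ids = tokenizer.encode("Hello world")
tokens = tokenizer.convert_ids_to_tokens(ids)      # e.g. ['Hello', 'Ġworld']
text = tokenizer.convert_tokens_to_string(tokens)  # "Hello world"
assert text == tokenizer.decode(ids)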
