Merge branch 'batch-serving' into parallel-sampling
masahi committed Dec 11, 2023
2 parents 1853a54 + 3774516 commit dafe391
Showing 7 changed files with 312 additions and 82 deletions.
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/base.py
@@ -6,7 +6,7 @@

from typing import List, Callable, Any, Optional, Dict
import inspect

from .streamer import TextStreamer
from .sampling_params import SamplingParams, SamplingType

RequestId = str
@@ -271,6 +271,7 @@ class GenerationSequence:
    generated_token_ids: list[int]
    next_start_position: int
    output_text: str
    text_streamer: TextStreamer
    is_finished: bool = False


30 changes: 4 additions & 26 deletions serve/mlc_serve/engine/engine_common.py
@@ -28,6 +28,7 @@
    Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from .streamer import TextStreamer

LOG = structlog.stdlib.get_logger(__name__)

@@ -52,6 +53,7 @@ def get_new_request_state(
            generated_token_ids=[],
            next_start_position=0,
            output_text="",
            text_streamer=TextStreamer(tokenizer),
        )
        for i in range(request.num_sequences)
    ]
@@ -68,30 +70,6 @@ def get_new_request_state(
)


def decode_last_output(
    prompt_tokens: list[int],
    generation_sequence: GenerationSequence,
    tokenizer: Tokenizer,
) -> str:
    if len(generation_sequence.output_text):
        prefix_idx = max(0, generation_sequence.next_start_position - 6)
    else:
        prefix_idx = generation_sequence.next_start_position

    # TODO(masahi): Figure out a way to remove this concat
    token_ids = prompt_tokens + generation_sequence.generated_token_ids

    if prefix_idx == 0:
        return tokenizer.decode(token_ids)

    prefix = tokenizer.decode(
        token_ids[prefix_idx : generation_sequence.next_start_position]
    )
    full = tokenizer.decode(token_ids[prefix_idx:])

    return full[len(prefix) :]


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
    if stopping_criteria.stop_sequences:
        for t in stopping_criteria.stop_sequences:
@@ -115,12 +93,12 @@ def update_sequence(
    gen_seq: GenerationSequence,
    new_token_ids: list[int],
    prompt_token_ids: list[int],
    tokenizer: Tokenizer,
    stopping_criteria: StoppingCriteria,
) -> str:
    gen_seq.next_start_position = len(prompt_token_ids) + len(gen_seq.generated_token_ids)
    gen_seq.generated_token_ids.extend(new_token_ids)
    delta = decode_last_output(prompt_token_ids, gen_seq, tokenizer)

    delta = gen_seq.text_streamer.put([new_token_ids[-1]])
    gen_seq.output_text += delta

    gen_seq.output_text, delta, gen_seq.is_finished = check_stopping_sequences(
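Since each GenerationSequence now carries its own TextStreamer (created per sequence in get_new_request_state), parallel samples from a single request detokenize independently, and update_sequence only has to feed the newest token to that sequence's streamer. A rough sketch of that per-sequence behaviour, assuming serve/ is on the import path; CodepointTokenizer is a stand-in invented for illustration, not the project's Tokenizer wrapper:

from mlc_serve.engine.streamer import TextStreamer


class CodepointTokenizer:
    """Stand-in tokenizer: each token id is a Unicode code point."""

    def decode(self, token_ids):
        return "".join(chr(t) for t in token_ids)


tok = CodepointTokenizer()
# One streamer per generated sequence, mirroring get_new_request_state.
streamers = [TextStreamer(tok) for _ in range(2)]
outputs = ["", ""]

# Interleave tokens from two sequences of the same request; each streamer
# keeps its own prefix/pending state, so the outputs never mix.
for a, b in zip("foo", "bar"):
    outputs[0] += streamers[0].put([ord(a)])
    outputs[1] += streamers[1].put([ord(b)])
outputs[0] += streamers[0].finish()
outputs[1] += streamers[1].finish()

print(outputs)  # ['foo', 'bar']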
2 changes: 0 additions & 2 deletions serve/mlc_serve/engine/staging_engine.py
@@ -31,7 +31,6 @@
    ShutdownCommand,
    run_generation_loop_worker,
)

from ..logging_utils import log_every

LOG = structlog.stdlib.get_logger(__name__)
@@ -225,7 +224,6 @@ def step(self) -> InferenceStepResult:
                gen_seq,
                new_token_ids,
                state.prompt_token_ids,
                self.tokenizer,
                state.stopping_criteria,
            )

95 changes: 95 additions & 0 deletions serve/mlc_serve/engine/streamer.py
@@ -0,0 +1,95 @@
from typing import List, Deque
from collections import deque

# U+FFFD, the replacement character produced while a multi-byte UTF-8
# sequence is still incomplete.
kReplacementCharacter = b"\xef\xbf\xbd".decode("utf8")


class TextStreamer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.prefix_tokens: List[int] = []
        self.pending_tokens: Deque[int] = deque([])

    def put(self, delta_tokens: List[int]) -> str:
        if len(delta_tokens) == 0:
            return ""

        ret = ""
        for delta_token in delta_tokens:
            self.pending_tokens.append(delta_token)
            all_tokens = self.prefix_tokens + list(self.pending_tokens)

            prefix_str = (
                self.tokenizer.decode(self.prefix_tokens)
                if len(self.prefix_tokens) > 0
                else ""
            )
            full_str = self.tokenizer.decode(all_tokens)
            prefix_len = len(prefix_str)

            new_pending_tokens: Deque[int] = deque([])
            if full_str[:prefix_len] == prefix_str:
                # Case 1. prefix_str is a prefix of `full_str`.
                # We cannot naively do `validated_str = self.tokenizer.decode(self.pending_tokens)`
                # since it would lose contextual information, such as a leading ' '.
                validated_str = full_str[prefix_len:]
                while (
                    len(self.pending_tokens) > 0
                    and len(new_pending_tokens) < 3
                    and len(validated_str) >= 1
                    and validated_str[len(validated_str) - 1 :] == kReplacementCharacter
                ):
                    new_pending_tokens.appendleft(self.pending_tokens.pop())
                    validated_str = validated_str[: len(validated_str) - 1]
            else:
                # Case 2. prefix_str is not a prefix of `full_str`.
                # Pop pending tokens from the back.
                # - Pop until prefix_str is indeed a prefix of full_str.
                # - A valid UTF-8 character is at most 4 bytes long,
                #   so there will be at most 3 tokens popped.
                # - If there are fewer than 3 pending tokens, skip popping.
                #   This is because it is impossible to make full_str contain
                #   prefix_str without popping all the pending tokens.
                if len(self.pending_tokens) < 3:
                    continue
                get_valid_full_str = False
                while len(self.pending_tokens) > 0 and len(new_pending_tokens) < 3:
                    new_pending_tokens.appendleft(self.pending_tokens.pop())
                    all_tokens.pop()
                    full_str = self.tokenizer.decode(all_tokens)
                    if full_str[:prefix_len] == prefix_str:
                        get_valid_full_str = True
                        break
                if get_valid_full_str:
                    # We found a full_str that starts with prefix_str,
                    # so we return the full string with the prefix sliced off.
                    validated_str = full_str[prefix_len:]
                else:
                    # We could not find a full_str that starts with prefix_str
                    # by popping at most 3 tokens.
                    # In this case, the remaining pending tokens are already
                    # invalid UTF-8, so we return the decoded pending tokens.
                    validated_str = self.tokenizer.decode(self.pending_tokens)

            if len(self.pending_tokens) > 0:
                # set the new prefix
                self.prefix_tokens = list(self.pending_tokens)
            self.pending_tokens = new_pending_tokens

            ret += validated_str
        return ret

    def finish(self) -> str:
        all_tokens = self.prefix_tokens + list(self.pending_tokens)
        prefix_str = (
            self.tokenizer.decode(self.prefix_tokens)
            if len(self.prefix_tokens) > 0
            else ""
        )
        full_str = self.tokenizer.decode(all_tokens) if len(all_tokens) > 0 else ""
        prefix_len = len(prefix_str)

        if full_str[:prefix_len] == prefix_str:
            return full_str[prefix_len:]
        else:
            return self.tokenizer.decode(self.pending_tokens)
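TextStreamer performs incremental detokenization: put() buffers tokens whose partial decode still ends in U+FFFD and releases text only once it decodes cleanly, and finish() flushes whatever is left. A minimal sketch of that contract, assuming serve/ is on the import path; ByteTokenizer is a toy stand-in invented for illustration (each "token" is one UTF-8 byte), not the project's real Tokenizer:

from mlc_serve.engine.streamer import TextStreamer


class ByteTokenizer:
    """Toy tokenizer whose token ids are raw UTF-8 byte values."""

    def decode(self, token_ids):
        # errors="replace" yields U+FFFD for an incomplete sequence, which is
        # exactly the condition TextStreamer watches for.
        return bytes(token_ids).decode("utf-8", errors="replace")


streamer = TextStreamer(ByteTokenizer())

pieces = []
for token_id in "héllo".encode("utf-8"):
    # put() returns only newly validated text; it holds back tokens whose
    # decoded form still ends in the replacement character.
    pieces.append(streamer.put([token_id]))
pieces.append(streamer.finish())  # flush anything still pending

print("".join(pieces))  # -> héllo  (the "é" is emitted only on its second byte)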
1 change: 0 additions & 1 deletion serve/mlc_serve/engine/sync_engine.py
@@ -199,7 +199,6 @@ def step(self) -> InferenceStepResult:
                gen_seq,
                new_token_ids,
                state.prompt_token_ids,
                self.tokenizer,
                state.stopping_criteria,
            )
