
Commit

changes from jared
masahi committed Mar 14, 2024
1 parent 0caae1f commit 2b22f56
Showing 2 changed files with 112 additions and 86 deletions.
192 changes: 106 additions & 86 deletions serve/mlc_serve/engine/engine_common.py
@@ -5,7 +5,7 @@
 import torch
 import time
 import json
-from typing import Tuple, Deque, Dict, Optional, Callable, List
+from typing import Any, Tuple, Deque, Dict, Optional, Callable, List
 from collections import deque
 from threading import Condition, Lock
 from pathlib import Path
@@ -25,6 +25,7 @@
     DecodeRequest,
     PrefillRequest,
     EvalMultiQueryRequest,
+    FailedRequest,
     EvictedTokens,
     ConversationTemplate,
     KVCacheManager,
@@ -37,6 +38,7 @@
 from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
 from .constrained.fsm_cache import FSMCache
 from .constrained import build_regex_from_schema
+from mlc_serve.errors import JSONModeError
 
 LOG = structlog.stdlib.get_logger(__name__)
 
@@ -131,6 +133,8 @@ def detokenize_incrementally(
         prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
+
+        # TODO(@jroesch): guard here for out of bound token ids
         new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
         output_tokens = generation_sequence.prev_tokens + new_tokens

@@ -256,6 +260,18 @@ def prepare_output(
 
     return delta, out_logprob_info
 
+# TODO(@jroesch): fix typing here
+def _schema_to_regex_fsm(regex_fsm_cache, json_schema: Any) -> Any:
+    try:
+        # Convert schema into json string
+        json_schema = json.dumps(json_schema)
+        # Build a regex (grammar) from json string
+        json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
+        # Query fsm cache for FSM object
+        return regex_fsm_cache.query(json_regex)
+    except Exception as exc:
+        LOG.exception("An error occurred while building JSON mode FSM.", exc=exc)
+        raise JSONModeError("Failed to construct FSM.") from exc
 
 def get_requests_to_process(
     current_states: List[RequestState],
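
The new helper centralizes the schema-to-FSM conversion and maps any failure to a single JSONModeError. Below is a self-contained sketch of the same cache-and-wrap pattern; the cache class, the exception, and the regex construction are stand-ins, not the real FSMCache, JSONModeError, or build_regex_from_schema.

# --- illustrative sketch, not part of the commit ---
import json
import re
from typing import Any, Dict


class JSONModeErrorSketch(Exception):
    """Stand-in for mlc_serve.errors.JSONModeError."""


class RegexFSMCacheSketch:
    """Toy stand-in for FSMCache: memoizes one compiled object per distinct regex."""

    def __init__(self) -> None:
        self._cache: Dict[str, Any] = {}

    def query(self, regex_string: str) -> Any:
        if regex_string not in self._cache:
            # The real cache builds an FSM for constrained decoding;
            # a compiled re.Pattern stands in for it here.
            self._cache[regex_string] = re.compile(regex_string)
        return self._cache[regex_string]


def schema_to_regex_fsm_sketch(cache: RegexFSMCacheSketch, json_schema: Any) -> Any:
    """Same shape as _schema_to_regex_fsm: serialize, derive a regex, query the cache."""
    try:
        schema_str = json.dumps(json_schema)  # the real code feeds this to build_regex_from_schema
        json_regex = r"\{.*\}"  # placeholder for the schema-derived regex
        return cache.query(json_regex)
    except Exception as exc:
        raise JSONModeErrorSketch("Failed to construct FSM.") from exc


cache = RegexFSMCacheSketch()
fsm_a = schema_to_regex_fsm_sketch(cache, {"type": "object"})
fsm_b = schema_to_regex_fsm_sketch(cache, {"type": "object"})
assert fsm_a is fsm_b  # identical regexes hit the same cached object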
@@ -281,101 +297,106 @@ def get_requests_to_process(
 
     if is_prompt_batch:
         for state in current_states:
-            if is_evicted_parallel_sampling_request(state):
-                requests.append(
-                    PrefillRequest(
-                        request_id=state.request_id,
-                        token_ids=state.prompt_token_ids,
-                        prompt_mask=state.prompt_mask,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
-                    )
-                )
-
-                token_counts += len(state.prompt_token_ids)
-
-                for gen_seq in state.generation_sequences:
-                    # TODO(vvchernov): This is for repetion penalty
-                    # Not obvious EvalMultiQueryRequest needs this
-                    # Now empty instead of state.prompt_mask
-                    vocab_size = state.sampling_params.vocab_size
-                    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
-                    requests.append(
-                        EvalMultiQueryRequest(
-                            sequence_id=gen_seq.seq_id,
-                            num_past_tokens=state.prompt_len,
-                            prompt_mask=prompt_mask,
-                            queries=EvictedTokens(gen_seq.generated_token_ids),
-                            sampling_params=state.sampling_params,
-                        )
-                    )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        len(gen_seq.generated_token_ids) + 1,
-                    )
-
-                # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
-                # Prometheus metric?
-            elif not state.is_prefilled:
-                # Check if JSON mode is enabled
-                if state.sampling_params.json_schema is not None:
-                    # Convert schema into json string
-                    json_schema = json.dumps(state.sampling_params.json_schema)
-                    # Build a regex (grammar) from json string
-                    json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
-                    # Query fsm cache for FSM object
-                    state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)
-
-                if (
-                    state.num_sequences == 1
-                    and state.generation_sequences[0].generated_token_ids
-                ):
-                    # generated_token_ids is added for the case where the request is
-                    # recovering from cache eviction.
-                    token_ids = (
-                        state.prompt_token_ids
-                        + state.generation_sequences[0].generated_token_ids
-                    )
-                else:
-                    token_ids = state.prompt_token_ids
-
-                requests.append(
-                    PrefillRequest(
-                        request_id=state.request_id,
-                        token_ids=token_ids,
-                        prompt_mask=state.prompt_mask,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
-                    )
-                )
-
-                token_counts += len(state.prompt_token_ids)
+            try:
+                if is_evicted_parallel_sampling_request(state):
+                    requests.append(
+                        PrefillRequest(
+                            request_id=state.request_id,
+                            token_ids=state.prompt_token_ids,
+                            prompt_mask=state.prompt_mask,
+                            num_sequence=state.num_sequences,
+                            sampling_params=state.sampling_params,
+                        )
+                    )
+
+                    token_counts += len(state.prompt_token_ids)
+
+                    for gen_seq in state.generation_sequences:
+                        # TODO(vvchernov): This is for repetition penalty
+                        # Not obvious EvalMultiQueryRequest needs this
+                        # Now empty instead of state.prompt_mask
+                        vocab_size = state.sampling_params.vocab_size
+                        prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+                        requests.append(
+                            EvalMultiQueryRequest(
+                                sequence_id=gen_seq.seq_id,
+                                num_past_tokens=state.prompt_len,
+                                prompt_mask=prompt_mask,
+                                queries=EvictedTokens(gen_seq.generated_token_ids),
+                                sampling_params=state.sampling_params,
+                            )
+                        )
+                        cache_manager.extend(
+                            gen_seq.seq_id,
+                            len(gen_seq.generated_token_ids) + 1,
+                        )
+
+                    # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
+                    # Prometheus metric?
+                elif not state.is_prefilled:
+                    # Check if JSON mode is enabled
+                    if state.sampling_params.json_schema is not None:
+                        state.sampling_params.regex_fsm = _schema_to_regex_fsm(regex_fsm_cache, state.sampling_params.json_schema)
+                    if (
+                        state.num_sequences == 1
+                        and state.generation_sequences[0].generated_token_ids
+                    ):
+                        # generated_token_ids is added for the case where the request is
+                        # recovering from cache eviction.
+                        token_ids = (
+                            state.prompt_token_ids
+                            + state.generation_sequences[0].generated_token_ids
+                        )
+                    else:
+                        token_ids = state.prompt_token_ids
+
+                    requests.append(
+                        PrefillRequest(
+                            request_id=state.request_id,
+                            token_ids=token_ids,
+                            prompt_mask=state.prompt_mask,
+                            num_sequence=state.num_sequences,
+                            sampling_params=state.sampling_params,
+                        )
+                    )
+
+                    token_counts += len(state.prompt_token_ids)
+            except Exception as exc:
+                LOG.exception("An exception occurred creating internal request types.", request_id=state.request_id, exc=exc)
+                requests.append(FailedRequest(
+                    request_id=state.request_id,
+                ))
 
         LOG.debug(
             "Creating prompt batch.",
             num_requests=len(requests),
            total_tokens=token_counts,
         )
     else:
         for state in current_states:
-            for gen_seq in state.generation_sequences:
-                if not gen_seq.is_finished:
-                    prompt_counts = len(state.prompt_token_ids)
-                    requests.append(
-                        DecodeRequest(
-                            sequence_id=gen_seq.seq_id,
-                            prompt_token_counts=prompt_counts,
-                            prompt_mask=state.prompt_mask,
-                            token_ids=gen_seq.generated_token_ids,
-                            sampling_params=state.sampling_params,
-                        )
-                    )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        prompt_counts
-                        + len(gen_seq.generated_token_ids)
-                        - gen_seq.next_start_position,
-                    )
+            try:
+                for gen_seq in state.generation_sequences:
+                    if not gen_seq.is_finished:
+                        prompt_counts = len(state.prompt_token_ids)
+                        requests.append(
+                            DecodeRequest(
+                                sequence_id=gen_seq.seq_id,
+                                prompt_token_counts=prompt_counts,
+                                prompt_mask=state.prompt_mask,
+                                token_ids=gen_seq.generated_token_ids,
+                                sampling_params=state.sampling_params,
+                            )
+                        )
+                        cache_manager.extend(
+                            gen_seq.seq_id,
+                            prompt_counts
+                            + len(gen_seq.generated_token_ids)
+                            - gen_seq.next_start_position,
+                        )
+            except Exception as exc:
+                LOG.exception("An exception occurred creating internal request types.", request_id=state.request_id, exc=exc)
+                requests.append(FailedRequest(
+                    request_id=state.request_id,
+                ))
 
         token_counts = len(requests)
         LOG.debug("Creating decode batch with %s requests.", token_counts)
@@ -389,8 +410,7 @@ def should_stop_by_length(
     max_context_length: int,
     max_tokens: Optional[int],
 ) -> bool:
-    # If max_tokens is None, we do not put any length restriction.
-    if gen_seq.is_finished or max_tokens is None:
+    if gen_seq.is_finished:
         return False
 
     num_context_tokens = prompt_len + len(gen_seq.generated_token_ids)
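
With the early return removed, max_tokens being None no longer bypasses the length check entirely, presumably so that such requests are still capped by the model context length. The rest of the function is collapsed in this view, so the following is only a hypothetical sketch of a typical check; parameter and variable names are illustrative.

# --- illustrative sketch, not part of the commit ---
from typing import Optional


def should_stop_by_length_sketch(
    num_generated_tokens: int,
    prompt_len: int,
    max_context_length: int,
    max_tokens: Optional[int],
) -> bool:
    """Stop at the context window; apply max_tokens only when the request set one."""
    num_context_tokens = prompt_len + num_generated_tokens
    if num_context_tokens >= max_context_length:
        return True
    return max_tokens is not None and num_generated_tokens >= max_tokens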
6 changes: 6 additions & 0 deletions serve/mlc_serve/engine/model_module.py
@@ -71,6 +71,12 @@ class EvalMultiQueryRequest:
     sampling_params: SamplingParams
 
 
+@dataclass
+class FailedRequest:
+    request_id: RequestId
+    error: Exception
+
+
 RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
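
A small, hypothetical consumer-side sketch (not from this commit) of how a batch built by get_requests_to_process could be split so FailedRequest entries are surfaced to callers while the remaining requests proceed; the import path is an assumption based on the file layout shown above.

# --- illustrative sketch, not part of the commit ---
from typing import List, Tuple

# Assumed import path for the new dataclass.
from mlc_serve.engine.model_module import FailedRequest


def split_failed(requests: List[object]) -> Tuple[List[object], List[FailedRequest]]:
    """Separate runnable requests from FailedRequest markers."""
    failed = [r for r in requests if isinstance(r, FailedRequest)]
    runnable = [r for r in requests if not isinstance(r, FailedRequest)]
    return runnable, failed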


