From 2b22f569bd1d4bc7f12b6ef5a29856f3d67be000 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 14 Mar 2024 20:34:27 +0000
Subject: [PATCH] changes from jared

---
 serve/mlc_serve/engine/engine_common.py | 192 +++++++++++++-----------
 serve/mlc_serve/engine/model_module.py  |   6 +
 2 files changed, 112 insertions(+), 86 deletions(-)

diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 918330f032..a46f82a2d6 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -5,7 +5,7 @@
 import torch
 import time
 import json
-from typing import Tuple, Deque, Dict, Optional, Callable, List
+from typing import Any, Tuple, Deque, Dict, Optional, Callable, List
 from collections import deque
 from threading import Condition, Lock
 from pathlib import Path
@@ -25,6 +25,7 @@
     DecodeRequest,
     PrefillRequest,
     EvalMultiQueryRequest,
+    FailedRequest,
     EvictedTokens,
     ConversationTemplate,
     KVCacheManager,
@@ -37,6 +38,7 @@
 from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
 from .constrained.fsm_cache import FSMCache
 from .constrained import build_regex_from_schema
+from mlc_serve.errors import JSONModeError
 
 LOG = structlog.stdlib.get_logger(__name__)
 
@@ -131,6 +133,8 @@ def detokenize_incrementally(
         prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
+
+        # TODO(@jroesch): guard here for out of bound token ids
         new_tokens = tokenizer.convert_ids_to_tokens([new_token_id])
         output_tokens = generation_sequence.prev_tokens + new_tokens
 
@@ -256,6 +260,18 @@ def prepare_output(
 
     return delta, out_logprob_info
 
+# TODO(@jroesch): fix typing here
+def _schema_to_regex_fsm(regex_fsm_cache, json_schema: Any) -> Any:
+    try:
+        # Convert schema into json string
+        json_schema = json.dumps(json_schema)
+        # Build a regex (grammar) from json string
+        json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
+        # Query fsm cache for FSM object
+        return regex_fsm_cache.query(json_regex)
+    except Exception as exc:
+        LOG.exception("An error occurred while building JSON mode FSM.", exc=exc)
+        raise JSONModeError("Failed to construct FSM.") from exc
 
 def get_requests_to_process(
     current_states: List[RequestState],
@@ -281,101 +297,106 @@
     if is_prompt_batch:
         for state in current_states:
-            if is_evicted_parallel_sampling_request(state):
-                requests.append(
-                    PrefillRequest(
-                        request_id=state.request_id,
-                        token_ids=state.prompt_token_ids,
-                        prompt_mask=state.prompt_mask,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
-                    )
-                )
-
-                token_counts += len(state.prompt_token_ids)
-
-                for gen_seq in state.generation_sequences:
-                    # TODO(vvchernov): This is for repetion penalty
-                    # Not obvious EvalMultiQueryRequest needs this
-                    # Now empty instead of state.prompt_mask
-                    vocab_size = state.sampling_params.vocab_size
-                    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+            try:
+                if is_evicted_parallel_sampling_request(state):
                     requests.append(
-                        EvalMultiQueryRequest(
-                            sequence_id=gen_seq.seq_id,
-                            num_past_tokens=state.prompt_len,
-                            prompt_mask=prompt_mask,
-                            queries=EvictedTokens(gen_seq.generated_token_ids),
+                        PrefillRequest(
+                            request_id=state.request_id,
+                            token_ids=state.prompt_token_ids,
+                            prompt_mask=state.prompt_mask,
+                            num_sequence=state.num_sequences,
                             sampling_params=state.sampling_params,
                         )
                     )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        len(gen_seq.generated_token_ids) + 1,
-                    )
-
-                # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
-                # Prometheus metric?
-            elif not state.is_prefilled:
-                # Check if JSON mode is enabled
-                if state.sampling_params.json_schema is not None:
-                    # Convert schema into json string
-                    json_schema = json.dumps(state.sampling_params.json_schema)
-                    # Build a regex (grammar) from json string
-                    json_regex = build_regex_from_schema(json_schema, whitespace_pattern=r"[ \n\t]?")
-                    # Query fsm cache for FSM object
-                    state.sampling_params.regex_fsm = regex_fsm_cache.query(json_regex)
-
-                if (
-                    state.num_sequences == 1
-                    and state.generation_sequences[0].generated_token_ids
-                ):
-                    # generated_token_ids is added for the case where the request is
-                    # recovering from cache eviction.
-                    token_ids = (
-                        state.prompt_token_ids
-                        + state.generation_sequences[0].generated_token_ids
-                    )
-                else:
-                    token_ids = state.prompt_token_ids
-
-                requests.append(
-                    PrefillRequest(
-                        request_id=state.request_id,
-                        token_ids=token_ids,
-                        prompt_mask=state.prompt_mask,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
-                    )
-                )
-
-                token_counts += len(state.prompt_token_ids)
-
-        LOG.debug(
-            "Creating prompt batch.",
-            num_requests=len(requests),
-            total_tokens=token_counts,
-        )
-    else:
-        for state in current_states:
-            for gen_seq in state.generation_sequences:
-                if not gen_seq.is_finished:
-                    prompt_counts = len(state.prompt_token_ids)
-                    requests.append(
-                        DecodeRequest(
-                            sequence_id=gen_seq.seq_id,
-                            prompt_token_counts=prompt_counts,
-                            prompt_mask=state.prompt_mask,
-                            token_ids=gen_seq.generated_token_ids,
-                            sampling_params=state.sampling_params,
-                        )
-                    )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        prompt_counts
-                        + len(gen_seq.generated_token_ids)
-                        - gen_seq.next_start_position,
-                    )
+
+                    token_counts += len(state.prompt_token_ids)
+
+                    for gen_seq in state.generation_sequences:
+                        # TODO(vvchernov): This is for repetition penalty
+                        # Not obvious EvalMultiQueryRequest needs this
+                        # Now empty instead of state.prompt_mask
+                        vocab_size = state.sampling_params.vocab_size
+                        prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
+                        requests.append(
+                            EvalMultiQueryRequest(
+                                sequence_id=gen_seq.seq_id,
+                                num_past_tokens=state.prompt_len,
+                                prompt_mask=prompt_mask,
+                                queries=EvictedTokens(gen_seq.generated_token_ids),
+                                sampling_params=state.sampling_params,
+                            )
+                        )
+                        cache_manager.extend(
+                            gen_seq.seq_id,
+                            len(gen_seq.generated_token_ids) + 1,
+                        )
+
+                    # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
+                    # Prometheus metric?
+                elif not state.is_prefilled:
+                    # Check if JSON mode is enabled
+                    if state.sampling_params.json_schema is not None:
+                        state.sampling_params.regex_fsm = _schema_to_regex_fsm(regex_fsm_cache, state.sampling_params.json_schema)
+                    if (
+                        state.num_sequences == 1
+                        and state.generation_sequences[0].generated_token_ids
+                    ):
+                        # generated_token_ids is added for the case where the request is
+                        # recovering from cache eviction.
+                        token_ids = (
+                            state.prompt_token_ids
+                            + state.generation_sequences[0].generated_token_ids
+                        )
+                    else:
+                        token_ids = state.prompt_token_ids
+
+                    requests.append(
+                        PrefillRequest(
+                            request_id=state.request_id,
+                            token_ids=token_ids,
+                            prompt_mask=state.prompt_mask,
+                            num_sequence=state.num_sequences,
+                            sampling_params=state.sampling_params,
+                        )
+                    )
+
+                    token_counts += len(state.prompt_token_ids)
+            except Exception as exc:
+                LOG.exception("An exception occurred creating internal request types.", request_id=state.request_id, exc=exc)
+                requests.append(FailedRequest(
+                    request_id=state.request_id,
+                ))
+        LOG.debug(
+            "Creating prompt batch.",
+            num_requests=len(requests),
+            total_tokens=token_counts,
+        )
+    else:
+        for state in current_states:
+            try:
+                for gen_seq in state.generation_sequences:
+                    if not gen_seq.is_finished:
+                        prompt_counts = len(state.prompt_token_ids)
+                        requests.append(
+                            DecodeRequest(
+                                sequence_id=gen_seq.seq_id,
+                                prompt_token_counts=prompt_counts,
+                                prompt_mask=state.prompt_mask,
+                                token_ids=gen_seq.generated_token_ids,
+                                sampling_params=state.sampling_params,
+                            )
+                        )
+                        cache_manager.extend(
+                            gen_seq.seq_id,
+                            prompt_counts
+                            + len(gen_seq.generated_token_ids)
+                            - gen_seq.next_start_position,
+                        )
+            except Exception as exc:
+                LOG.exception("An exception occurred creating internal request types.", request_id=state.request_id, exc=exc)
+                requests.append(FailedRequest(
+                    request_id=state.request_id,
+                ))
 
     token_counts = len(requests)
     LOG.debug("Creating decode batch with %s requests.", token_counts)
 
@@ -389,8 +410,7 @@ def should_stop_by_length(
     max_context_length: int,
     max_tokens: Optional[int],
 ) -> bool:
-    # If max_tokens is None, we do not put any length restriction.
-    if gen_seq.is_finished or max_tokens is None:
+    if gen_seq.is_finished:
         return False
 
     num_context_tokens = prompt_len + len(gen_seq.generated_token_ids)
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 302cd623ff..a97322ff3e 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -71,6 +71,12 @@ class EvalMultiQueryRequest:
     sampling_params: SamplingParams
 
 
+@dataclass
+class FailedRequest:
+    request_id: RequestId
+    error: Exception
+
+
 RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
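
A note on the new error path (editorial, not part of the patch to apply): FailedRequest is declared with a required error: Exception field, but both except blocks above construct it with request_id only, and the RequestType union is not extended to include FailedRequest, so downstream typing and dispatch may need a follow-up. A minimal sketch of the shape this seems to be aiming for, assuming the call sites are meant to attach the caught exception; only the FailedRequest fields come from the patch, the other names are illustrative:

    # Sketch only -- mirrors the FailedRequest dataclass added in
    # model_module.py; RequestId here is a stand-in for the real alias.
    from dataclasses import dataclass

    RequestId = str


    @dataclass
    class FailedRequest:
        request_id: RequestId
        error: Exception


    def to_failed_request(request_id: RequestId, exc: Exception) -> FailedRequest:
        # The except blocks in get_requests_to_process would need to pass the
        # caught exception explicitly (error=exc), since the field has no default.
        return FailedRequest(request_id=request_id, error=exc)

Relatedly, should_stop_by_length keeps the max_tokens: Optional[int] parameter while dropping the "max_tokens is None" early return; if the rest of that function (not shown in this hunk) still compares against max_tokens, the None case may need to be handled there instead.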