diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index 3de28fe3b5..1ac19de062 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -73,7 +73,6 @@ def _get_sampling_params(
         sampling_params.logprobs = request.logprobs
     if request.response_format and request.response_format.type == "json_object":
         sampling_params.json_schema = request.response_format.response_schema
-    sampling_params.vocab_size = model_artifact_config.vocab_size
     return sampling_params
 
 
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 1fe341867e..acf1faa90d 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -263,14 +263,7 @@ def get_requests_to_process(
 
     if is_prompt_batch:
         for state in current_states:
-<<<<<<< HEAD
             if is_evicted_parallel_sampling_request(state):
-=======
-            if not state.is_prefilled:
-                # `JSONLogitsProcessor` needs to be created only once.
-                if state.sampling_params.json_schema is not None:
-                    state.sampling_params.logits_processor = JSONLogitsProcessor(state.sampling_params.json_schema, tokenizer._tokenizer)
->>>>>>> works
                 requests.append(
                     PrefillRequest(
                         request_id=state.request_id,
@@ -299,6 +292,12 @@
             # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
             # Prometheus metric?
             elif not state.is_prefilled:
+                # `JSONLogitsProcessor` needs to be created only once.
+                if state.sampling_params.json_schema is not None:
+                    state.sampling_params.logits_processor = JSONLogitsProcessor(
+                        state.sampling_params.json_schema, tokenizer._tokenizer
+                    )
+
                 if (
                     state.num_sequences == 1
                     and state.generation_sequences[0].generated_token_ids
@@ -317,14 +316,7 @@
                         request_id=state.request_id,
                         token_ids=token_ids,
                         num_sequence=state.num_sequences,
-<<<<<<< HEAD
                         sampling_params=state.sampling_params,
-                        logit_processor=JSONLogitsProcessor(
-                            state.sampling_params.json_schema, tokenizer._tokenizer
-                        ),
-=======
-                        sampling_params=state.sampling_params
->>>>>>> works
                     )
                 )
 
@@ -346,12 +338,6 @@
                             prompt_token_counts=prompt_counts,
                             token_ids=gen_seq.generated_token_ids,
                             sampling_params=state.sampling_params,
-<<<<<<< HEAD
-                            logit_processor=JSONLogitsProcessor(
-                                state.sampling_params.json_schema, tokenizer._tokenizer
-                            ),
-=======
->>>>>>> works
                         )
                     )
                     cache_manager.extend(
diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py
index 1dda0cee94..d5a010e4e5 100644
--- a/serve/mlc_serve/model/model_common.py
+++ b/serve/mlc_serve/model/model_common.py
@@ -14,6 +14,7 @@
 )
 from ..engine.model_module import (
     PrefillRequest,
+    DecodeRequest,
     EvalMultiQueryRequest,
     RequestType,
     TextGenerationResult,
@@ -97,6 +98,17 @@ def sample_from_logits(
     # synchronization point for sampling tensors
     # wait until all the tensors are loaded on GPU
     torch.cuda.current_stream().wait_stream(copy_stream)
+
+    # Logit processing for constraint sampling e.g., JSON Mode
+    for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
+        if request.sampling_params.logits_processor is not None:
+            cs_input_ids = (
+                request.token_ids if isinstance(request, DecodeRequest) else []
+            )
+            logits[i] = request.sampling_params.logits_processor(
+                sequence_id, cs_input_ids, logits[i]
+            )
+
     logits = adjust_logits(logits, sampling_metadata, vocab_size)
     outputs: List[TextGenerationResult] = []
 
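
For reference, the loop added to sample_from_logits above assumes that sampling_params.logits_processor is a stateful callable invoked once per batch row as processor(sequence_id, generated_token_ids, logits) -> logits, with [] passed for prefill rows and the generated tokens passed for decode rows. The sketch below is a minimal, hypothetical stand-in with that call signature, not the JSONLogitsProcessor wired up in this patch; the class name and fixed allow-list are illustrative assumptions, kept to plain torch so it runs on its own.

import torch


class AllowedTokenLogitsProcessor:
    """Toy stand-in for a per-request constrained-sampling logits processor.

    It only mirrors the signature used in sample_from_logits:
    processor(sequence_id, generated_token_ids, logits) -> logits.
    A real JSON-mode processor would advance a grammar/FSM state per
    sequence instead of applying a fixed allow-list.
    """

    def __init__(self, allowed_token_ids):
        self.allowed = torch.tensor(allowed_token_ids, dtype=torch.long)
        self.num_seen = {}  # sequence_id -> tokens generated so far

    def __call__(self, sequence_id, generated_token_ids, logits):
        # Track per-sequence progress: prefill passes [], decode passes
        # the tokens generated so far for this sequence.
        self.num_seen[sequence_id] = len(generated_token_ids)

        # Mask every token id outside the allow-list.
        mask = torch.full_like(logits, float("-inf"))
        mask[self.allowed] = 0.0
        return logits + mask


# Usage: constrain one row of a logits batch to three token ids.
processor = AllowedTokenLogitsProcessor(allowed_token_ids=[1, 5, 7])
row = torch.randn(32)              # logits for one sequence
row = processor("seq-0", [], row)  # prefill step: nothing generated yet
assert int(row.argmax()) in (1, 5, 7)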