Commit

rebased
sunggg committed Feb 8, 2024
1 parent 323172f commit 982d751
Showing 3 changed files with 18 additions and 21 deletions.
serve/mlc_serve/api/handler.py (1 change: 0 additions & 1 deletion)
@@ -73,7 +73,6 @@ def _get_sampling_params(
     sampling_params.logprobs = request.logprobs
     if request.response_format and request.response_format.type == "json_object":
         sampling_params.json_schema = request.response_format.response_schema
-
     sampling_params.vocab_size = model_artifact_config.vocab_size
     return sampling_params

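For context, the `response_format` handling above is what turns an incoming JSON-mode request into a `json_schema` on the sampling params. A minimal sketch of a client request that would exercise this path, assuming an OpenAI-compatible chat completions endpoint; the endpoint URL, port, and model name are assumptions, while the `response_format` / `response_schema` field names mirror the handler code above:

# Illustrative client call, not part of the commit. Endpoint path, port, and
# model name are assumptions; "response_format" and "response_schema" mirror
# the handler fields shown above.
import requests

payload = {
    "model": "mistral-7b-instruct",  # assumed model id
    "messages": [
        {"role": "user", "content": "Return a JSON object with a 'name' field."}
    ],
    "response_format": {
        "type": "json_object",
        "response_schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
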
serve/mlc_serve/engine/engine_common.py (26 changes: 6 additions & 20 deletions)
@@ -263,14 +263,7 @@ def get_requests_to_process(
 
     if is_prompt_batch:
         for state in current_states:
-<<<<<<< HEAD
             if is_evicted_parallel_sampling_request(state):
-=======
-            if not state.is_prefilled:
-                # `JSONLogitsProcessor` needs to be created only once.
-                if state.sampling_params.json_schema is not None:
-                    state.sampling_params.logits_processor = JSONLogitsProcessor(state.sampling_params.json_schema, tokenizer._tokenizer)
->>>>>>> works
                 requests.append(
                     PrefillRequest(
                         request_id=state.request_id,
@@ -299,6 +292,12 @@ def get_requests_to_process(
                 # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
                 # Prometheus metric?
             elif not state.is_prefilled:
+                # `JSONLogitsProcessor` needs to be created only once.
+                if state.sampling_params.json_schema is not None:
+                    state.sampling_params.logits_processor = JSONLogitsProcessor(
+                        state.sampling_params.json_schema, tokenizer._tokenizer
+                    )
+
                 if (
                     state.num_sequences == 1
                     and state.generation_sequences[0].generated_token_ids
@@ ... @@
                         request_id=state.request_id,
                         token_ids=token_ids,
                         num_sequence=state.num_sequences,
-<<<<<<< HEAD
-                        sampling_params=state.sampling_params,
-                        logit_processor=JSONLogitsProcessor(
-                            state.sampling_params.json_schema, tokenizer._tokenizer
-                        ),
-=======
                         sampling_params=state.sampling_params
->>>>>>> works
                     )
                 )

@@ ... @@
                             prompt_token_counts=prompt_counts,
                             token_ids=gen_seq.generated_token_ids,
                             sampling_params=state.sampling_params,
-<<<<<<< HEAD
-                            logit_processor=JSONLogitsProcessor(
-                                state.sampling_params.json_schema, tokenizer._tokenizer
-                            ),
-=======
->>>>>>> works
                         )
                     )
                     cache_manager.extend(
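
The net effect of the conflict resolution above: instead of constructing a `JSONLogitsProcessor` and passing it as a separate `logit_processor=` argument on each request object, the engine builds it once per request, on the first not-yet-prefilled pass, and caches it on `sampling_params.logits_processor`. A minimal standalone sketch of that create-once pattern, using stand-in classes rather than the engine's real types:

# Stand-in sketch of the create-once caching pattern; names other than
# json_schema / logits_processor are illustrative, not the engine's real API.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeSamplingParams:
    json_schema: Optional[dict] = None
    logits_processor: Optional[Any] = None


class FakeJSONProcessor:
    built = 0  # count constructions to show the cache works

    def __init__(self, schema):
        FakeJSONProcessor.built += 1
        self.schema = schema


def schedule_step(sampling_params, is_prefilled):
    # Mirrors the resolved branch: only build the processor on the first
    # (not-yet-prefilled) pass; later steps reuse the cached instance.
    if not is_prefilled and sampling_params.json_schema is not None:
        sampling_params.logits_processor = FakeJSONProcessor(sampling_params.json_schema)


params = FakeSamplingParams(json_schema={"type": "object"})
schedule_step(params, is_prefilled=False)  # first pass: processor is built
schedule_step(params, is_prefilled=True)   # subsequent passes: reused
schedule_step(params, is_prefilled=True)
assert FakeJSONProcessor.built == 1
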
serve/mlc_serve/model/model_common.py (12 changes: 12 additions & 0 deletions)
@@ -14,6 +14,7 @@
 )
 from ..engine.model_module import (
     PrefillRequest,
+    DecodeRequest,
     EvalMultiQueryRequest,
     RequestType,
     TextGenerationResult,
@@ -97,6 +98,17 @@ def sample_from_logits(
     # synchronization point for sampling tensors
     # wait until all the tensors are loaded on GPU
     torch.cuda.current_stream().wait_stream(copy_stream)
+
+    # Logit processing for constraint sampling e.g., JSON Mode
+    for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
+        if request.sampling_params.logits_processor is not None:
+            cs_input_ids = (
+                request.token_ids if isinstance(request, DecodeRequest) else []
+            )
+            logits[i] = request.sampling_params.logits_processor(
+                sequence_id, cs_input_ids, logits[i]
+            )
+
     logits = adjust_logits(logits, sampling_metadata, vocab_size)
     outputs: List[TextGenerationResult] = []


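In `sample_from_logits`, whatever object sits on `sampling_params.logits_processor` is invoked as `processor(sequence_id, input_ids, logits_row)` and must return a logits row of the same shape. A toy processor with that call shape is sketched below; it is illustrative only, not the real `JSONLogitsProcessor`, which constrains tokens so the generated text stays valid against a JSON schema:

# Toy processor with the same call shape used in sample_from_logits above;
# it simply masks out everything except an allow-list of token ids.
import torch


class AllowListLogitsProcessor:
    def __init__(self, allowed_token_ids):
        self.allowed_token_ids = list(allowed_token_ids)

    def __call__(self, sequence_id, input_ids, logits: torch.Tensor) -> torch.Tensor:
        # logits is the 1-D row for one sequence; return the filtered row.
        mask = torch.full_like(logits, float("-inf"))
        mask[self.allowed_token_ids] = 0.0
        return logits + mask


# Usage sketch: a vocab of 10 tokens where only ids 2 and 5 stay sampleable.
row = torch.randn(10)
processor = AllowListLogitsProcessor([2, 5])
filtered = processor(sequence_id=0, input_ids=[], logits=row)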