Fix for the case when max token is not set (#105)
masahi authored Dec 11, 2023
1 parent 745ce71 commit a356074
Showing 1 changed file with 7 additions and 6 deletions.
serve/mlc_serve/engine/engine_common.py (7 additions, 6 deletions)
@@ -118,7 +118,9 @@ def update_sequence(
     tokenizer: Tokenizer,
     stopping_criteria: StoppingCriteria,
 ) -> str:
-    gen_seq.next_start_position = len(prompt_token_ids) + len(gen_seq.generated_token_ids)
+    gen_seq.next_start_position = len(prompt_token_ids) + len(
+        gen_seq.generated_token_ids
+    )
     gen_seq.generated_token_ids.extend(new_token_ids)
     delta = decode_last_output(prompt_token_ids, gen_seq, tokenizer)
     gen_seq.output_text += delta
@@ -191,23 +193,22 @@ def should_stop_by_length(state: RequestState, max_context_length: int) -> bool:
     # TODO: currently, we simply return true for both stopping reasons.
     # in the future, we can differentiate these two.
     # this include prompt tokens and gen tokens so far
-    if state.is_finished:
+    if state.is_finished or state.stopping_criteria.max_tokens is None:
         return False

     for gen_seq in state.generation_sequences:
         if gen_seq.is_finished:
             continue

         num_context_tokens = state.prompt_len + len(gen_seq.generated_token_ids)

         if num_context_tokens >= max_context_length:
             gen_seq.is_finished = True
             continue

         num_gen_tokens = num_context_tokens - state.prompt_len
-        if (
-            state.stopping_criteria.max_tokens is not None
-            and num_gen_tokens < state.stopping_criteria.max_tokens
-        ):
+
+        if num_gen_tokens < state.stopping_criteria.max_tokens:
             return False

     return True
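
The change short-circuits should_stop_by_length when stopping_criteria.max_tokens is None. Before the fix, such a request fell through the per-sequence check (whose "max_tokens is not None" guard could never pass) and reached the final return True, so a request with no token cap was reported as stopped by length. Below is a minimal, self-contained sketch of the fixed logic; the dataclasses are simplified stand-ins for the real StoppingCriteria, GenerationSequence, and RequestState types in mlc_serve.engine, keeping only the fields this diff touches.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StoppingCriteria:
    max_tokens: Optional[int] = None  # None means no generation-length cap


@dataclass
class GenerationSequence:
    generated_token_ids: List[int] = field(default_factory=list)
    is_finished: bool = False


@dataclass
class RequestState:
    prompt_len: int
    generation_sequences: List[GenerationSequence]
    stopping_criteria: StoppingCriteria
    is_finished: bool = False


def should_stop_by_length(state: RequestState, max_context_length: int) -> bool:
    # The fixed early exit: with no max_tokens cap, length-based stopping
    # never applies to this request.
    if state.is_finished or state.stopping_criteria.max_tokens is None:
        return False

    for gen_seq in state.generation_sequences:
        if gen_seq.is_finished:
            continue

        num_context_tokens = state.prompt_len + len(gen_seq.generated_token_ids)

        if num_context_tokens >= max_context_length:
            gen_seq.is_finished = True
            continue

        num_gen_tokens = num_context_tokens - state.prompt_len

        if num_gen_tokens < state.stopping_criteria.max_tokens:
            return False

    return True


# Before this commit, a request with max_tokens=None fell through the inner
# check and the function returned True; with the fix it keeps running.
state = RequestState(
    prompt_len=4,
    generation_sequences=[GenerationSequence(generated_token_ids=[1, 2, 3])],
    stopping_criteria=StoppingCriteria(max_tokens=None),
)
assert should_stop_by_length(state, max_context_length=2048) is False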