diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index c306b13324..a2820b0bc5 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -194,7 +194,7 @@ def should_stop_by_length(state: RequestState, max_context_length: int) -> bool: # TODO: currently, we simply return true for both stopping reasons. # in the future, we can differentiate these two. # this include prompt tokens and gen tokens so far - if state.is_finished: + if state.is_finished or state.stopping_criteria.max_tokens is None: return False for gen_seq in state.generation_sequences: @@ -202,15 +202,14 @@ def should_stop_by_length(state: RequestState, max_context_length: int) -> bool: continue num_context_tokens = state.prompt_len + len(gen_seq.generated_token_ids) + if num_context_tokens >= max_context_length: gen_seq.is_finished = True continue num_gen_tokens = num_context_tokens - state.prompt_len - if ( - state.stopping_criteria.max_tokens is not None - and num_gen_tokens < state.stopping_criteria.max_tokens - ): + + if num_gen_tokens < state.stopping_criteria.max_tokens: return False return True