diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index e60fb23a24..1578abbef8 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -83,7 +83,7 @@ def start(self):
         LOG.info("StagingInferenceEngine.start")
         try:
             self.worker_process.start()
-            if not self.ready_event.wait(timeout=90):
+            if not self.ready_event.wait(timeout=120):
                 raise RuntimeError(
                     "StagingInferenceEngine worker is not ready before timeout."
                 )
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 583aea8f89..f5e4968dde 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -120,19 +120,23 @@ def add(self, request_states: list[RequestState]):
             # States which have been invalidated should never be added, directly
             # cancel them instead.
             valid_states = []
-            for request_state in request_states:
+            kv_cache_size = self.cache_manager.get_kv_cache_size()
+            max_prompt_len = min(self.max_context_length, self.max_num_batched_tokens)
+            for state in request_states:
                 if (
-                    request_state.validation_err is not None
-                    or request_state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
-                    # Need to exclude requests which cannot fit into the kv_cache and can be processed
-                    # at least max_decode_steps steps
-                    or self.cache_manager.get_kv_cache_size() - request_state.prompt_len < self.max_decode_steps
+                    state.validation_err is not None
+                    or state.prompt_len > max_prompt_len
+                    # We make sure that the KV cache will have enough free space for this request to proceed
+                    # decoding for at least self.max_decode_steps steps.
+                    or (kv_cache_size - state.prompt_len) < self.max_decode_steps
                 ):
-                    self.cancelled_requests.append(request_state)
-                    if request_state.validation_err is None:
-                        request_state.validation_err = ValidationError("The prompt is too long for the given set of engine parameters.")
+                    self.cancelled_requests.append(state)
+                    if state.validation_err is None:
+                        state.validation_err = ValidationError(
+                            "The prompt is too long for the given set of engine parameters."
+                        )
                 else:
-                    valid_states.append(request_state)
+                    valid_states.append(state)

             self.queue.extend(valid_states)
             self.has_new_requests.notify_all()
@@ -211,6 +215,7 @@ def step(self) -> GenerationLoopWorkerOutput:
         self.stopped_requests.clear()

         with self.queue_lock:
+            # Hold the lock here since self.cancelled_requests is modified in add(...) as well.
             for state in self.cancelled_requests:
                 err = None
                 if state.validation_err:
@@ -328,20 +333,25 @@ def _adjust_batch(self):
                 # self.max_num_batched_tokens. In such cases, we need to discard the recent decode
                 # tokens that cannot fit into a batch, and recompute them after we fill the cache
                 # entries for the older tokens.
-                if len(self.current_batch) == 0 and num_new_batched_tokens > self.max_num_batched_tokens:
-                    state.token_ids = state.token_ids[:self.max_num_batched_tokens]
-                    state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens
-                if num_new_batched_tokens > self.max_num_batched_tokens > 0:
+                if (
+                    len(self.current_batch) == 0
+                    and num_new_batched_tokens > self.max_num_batched_tokens
+                ):
+                    state.token_ids = state.token_ids[: self.max_num_batched_tokens]
+                    state.next_start_position = (
+                        num_new_batched_tokens
+                    ) = num_tokens = self.max_num_batched_tokens
+                if num_new_batched_tokens > self.max_num_batched_tokens:
                     LOG.debug(
                         "Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
                         num_new_batched_tokens,
                     )
                     break
-                # Make sure to leave some free space in the KV cache after a request is added or batched
-                if (
-                    (self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
-                    < self.max_decode_steps
-                ):
+                # We make sure that the KV cache will have enough free space for all sequences in the batch
+                # to proceed decoding for at least self.max_decode_steps steps.
+                if (self.cache_manager.get_free_space() - num_tokens) / (
+                    len(self.current_batch) + 1
+                ) < self.max_decode_steps:
                     LOG.debug(
                         "Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s",
                         self.cache_manager.get_free_space(),
diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index adc961c15b..a58cf04e36 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -18,7 +18,7 @@
     RequestState,
     SequenceOutput,
     check_stopping_sequences,
-    ValidationError
+    ValidationError,
 )
 from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
 from ..model.base import ModelArtifactConfig
@@ -96,8 +96,8 @@ def add(self, requests: list[Request]):
             if (
                 state.validation_err is not None
                 or state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
-                # Need to exclude requests which cannot fit into the kv_cache and can be processed
-                # at least max_decode_steps steps
+                # We make sure that the KV cache will have enough free space for this request to proceed
+                # decoding for at least self.max_decode_steps steps.
                 or self.cache_manager.get_kv_cache_size() - state.prompt_len < self.max_decode_steps
             ):
                 self.cancel(req.request_id)
@@ -296,7 +296,8 @@ def _adjust_batch(self):
                     num_new_batched_tokens,
                 )
                 break
-            # Make sure to leave some free space in the KV cache after a request is added or batched
+            # We make sure that the KV cache will have enough free space for all sequences in the batch
+            # to proceed decoding for at least self.max_decode_steps steps.
             if (
                 (self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
                 < self.max_decode_steps
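For context, a minimal standalone sketch of the two checks this patch converges on (request admission and batch growth). It is illustrative only: the function names, `kv_cache_size`, `free_space`, and the numbers in the example are hypothetical stand-ins, not the actual mlc_serve `CacheManager` API or its defaults.

```python
# Illustrative sketch only -- these helpers mirror the checks in the diff above;
# every name here is a hypothetical stand-in, not the real mlc_serve API.


def should_reject(
    prompt_len: int,
    kv_cache_size: int,
    max_context_length: int,
    max_num_batched_tokens: int,
    max_decode_steps: int,
) -> bool:
    """Reject a request up front if its prompt is too long, or if the KV cache
    could not hold the prompt plus at least max_decode_steps generated tokens."""
    max_prompt_len = min(max_context_length, max_num_batched_tokens)
    return (
        prompt_len > max_prompt_len
        or kv_cache_size - prompt_len < max_decode_steps
    )


def can_grow_batch(
    free_space: int,
    num_tokens: int,
    current_batch_size: int,
    max_decode_steps: int,
) -> bool:
    """Stop growing the batch once the remaining free KV cache space, averaged
    over all sequences after adding this one, drops below max_decode_steps."""
    return (free_space - num_tokens) / (current_batch_size + 1) >= max_decode_steps


if __name__ == "__main__":
    # A 4000-token prompt in a 4096-slot cache leaves only 96 slots of headroom,
    # fewer than the 256 decode steps required, so the request is rejected.
    assert should_reject(4000, kv_cache_size=4096, max_context_length=4096,
                         max_num_batched_tokens=4096, max_decode_steps=256)
    # (4096 - 512) / (3 + 1) = 896 >= 256, so this sequence can still be batched.
    assert can_grow_batch(free_space=4096, num_tokens=512,
                          current_batch_size=3, max_decode_steps=256)
```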