From d68566ba10af567deaf7efe25f83095bf6b7ac75 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 14 Mar 2024 21:08:42 +0000
Subject: [PATCH] Protect against invalid request format

---
 serve/mlc_serve/engine/staging_engine.py | 39 ++++++++++++++++++------
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index a457ff8385..f6ee03a751 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -57,6 +57,9 @@ def __init__(
         self.next_generation_output = None
         self.requests_lock = Lock()
         self.requests = dict[RequestId, RequestState]()
+        self.requests_to_be_cancelled_lock = Lock()
+        # Error message for each request that fails to be added to the engine
+        self.requests_to_be_cancelled = dict[RequestId, str]()
 
         # TODO(@team): This is a temporary solution to expose model config to higher API layer.
         # Follow-up with the proper solution
@@ -119,13 +122,17 @@ def add(self, requests: list[Request]):
                 assert isinstance(req.stopping_criteria.stop_sequences, list)
 
             # If the request violates the tokenization, this returns None, so skip.
-            state = get_new_request_state(
-                req,
-                self.conversation_template,
-                self.tokenizer,
-                self.model_artifact_config.vocab_size,
-            )
-            new_request_states.append(state)
+            try:
+                state = get_new_request_state(
+                    req,
+                    self.conversation_template,
+                    self.tokenizer,
+                    self.model_artifact_config.vocab_size,
+                )
+                new_request_states.append(state)
+            except Exception as e:
+                with self.requests_to_be_cancelled_lock:
+                    self.requests_to_be_cancelled[req.request_id] = str(e)
 
         self.command_queue.put(AddRequestsCommand(request_states=new_request_states))
 
@@ -171,11 +178,25 @@
             has_pending_requests=self.has_pending_requests(),
         )
 
+        outputs = list[RequestOutput]()
+
+        with self.requests_to_be_cancelled_lock:
+            if len(self.requests_to_be_cancelled) > 0:
+                for req_id, err_msg in self.requests_to_be_cancelled.items():
+                    outputs.append(
+                        RequestOutput(
+                            req_id,
+                            sequences=[],
+                            error=err_msg,
+                        )
+                    )
+                self.requests_to_be_cancelled.clear()
+
         if not self._is_ready_to_serve():
             raise RuntimeError("GenerationLoopWorker process is not running")
 
         if not self.has_pending_requests():
-            return InferenceStepResult([])
+            return InferenceStepResult(outputs)
 
         if self.next_generation_output is None:
             generation_output = self.result_queue.get()
@@ -188,8 +209,6 @@
                 f"Error from GenerationLoopWorker process: {generation_output.error}"
             ) from generation_output.error
 
-        outputs = list[RequestOutput]()
-
         with self.requests_lock:
             LOG.debug(
                 "StagingInferenceEngine.step obtained requests_lock",
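
Note (not part of the patch): below is a minimal, self-contained sketch of the error-reporting flow this change introduces. Requests that fail validation in add() are recorded under a dedicated lock instead of raising to the caller, and the next step() call surfaces each failure once as an output with error set and no sequences. TinyEngine, the simplified RequestOutput, and the empty-prompt validation check are illustrative assumptions for this sketch, not the real mlc_serve classes.

    from dataclasses import dataclass, field
    from threading import Lock
    from typing import Optional


    @dataclass
    class RequestOutput:
        # Simplified stand-in for mlc_serve's RequestOutput: a request id, the
        # generated sequences, and an optional error message.
        request_id: str
        sequences: list = field(default_factory=list)
        error: Optional[str] = None


    class TinyEngine:
        # Minimal sketch of the add()/step() error path this patch introduces.

        def __init__(self):
            self.requests_to_be_cancelled_lock = Lock()
            # Error message for each request that failed validation in add().
            self.requests_to_be_cancelled: dict[str, str] = {}

        def add(self, request_id: str, prompt: str) -> None:
            try:
                if not prompt:
                    # Stand-in for get_new_request_state() raising on a bad request.
                    raise ValueError("empty prompt")
                # A real engine would enqueue the validated request state here.
            except Exception as e:
                # Record the failure instead of propagating it to the caller.
                with self.requests_to_be_cancelled_lock:
                    self.requests_to_be_cancelled[request_id] = str(e)

        def step(self) -> list[RequestOutput]:
            outputs: list[RequestOutput] = []
            # Drain failed requests so each error is reported exactly once.
            with self.requests_to_be_cancelled_lock:
                for req_id, err_msg in self.requests_to_be_cancelled.items():
                    outputs.append(RequestOutput(req_id, sequences=[], error=err_msg))
                self.requests_to_be_cancelled.clear()
            return outputs


    engine = TinyEngine()
    engine.add("req-0", "")   # invalid request: recorded, not raised to the caller
    print(engine.step())      # one RequestOutput with error="empty prompt"
    print(engine.step())      # [] -- the error is not reported twice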