diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index c12ab362a1..8d38f4ed0c 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -104,7 +104,6 @@ def create_engine_and_tokenizer_module( "max_input_len": args.max_input_len, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, - "prompt_allocate_ratio": args.prompt_allocate_ratio }) if args.use_staging_engine: @@ -182,7 +181,6 @@ def main(args: argparse.Namespace): parser.add_argument("--max-input-len", type=int, default=512) parser.add_argument("--min-decode-steps", type=int, default=32) parser.add_argument("--max-decode-steps", type=int, default=56) - parser.add_argument("--prompt-allocate-ratio", type=float, default=2.0) parser.add_argument( "--num-prompts", type=int, default=1000, help="Number of prompts to process." ) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 43a9f5244e..865186c025 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -27,7 +27,6 @@ class MLCServeEngineConfig: max_num_batched_tokens: int = -1 min_decode_steps: int = 32 max_decode_steps: int = 48 - prompt_allocate_ratio: float = 2.0 @classmethod def _from_json(config_cls, json_obj: Dict[Any, Any]): @@ -39,30 +38,27 @@ def _from_json(config_cls, json_obj: Dict[Any, Any]): } ) -def get_engine_config(dict_config, enable_check = True): +def get_engine_config(dict_config): engine_config = MLCServeEngineConfig._from_json(dict_config) # Checks to make sure engine configs are set correctly # since engine config is critical to the performance - if enable_check: - assert isinstance(engine_config.use_staging_engine, bool) - assert isinstance(engine_config.max_num_batched_tokens, int) - assert isinstance(engine_config.max_input_len, int) - assert isinstance(engine_config.max_num_sequences, int) - assert isinstance(engine_config.max_decode_steps, int) - assert isinstance(engine_config.min_decode_steps, int) - assert isinstance(engine_config.prompt_allocate_ratio, float) - - # TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly - assert engine_config.max_num_batched_tokens == -1, \ - "`max_num_batched_tokens` is not supposed to be configured directly. \ - Use `max_num_sequences` and `max_input_len` instead." - assert engine_config.max_input_len > 0 - assert engine_config.max_num_sequences > 0 - engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len - - assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0) - assert engine_config.max_decode_steps > engine_config.min_decode_steps - assert engine_config.prompt_allocate_ratio > 0 + assert isinstance(engine_config.use_staging_engine, bool) + assert isinstance(engine_config.max_num_batched_tokens, int) + assert isinstance(engine_config.max_input_len, int) + assert isinstance(engine_config.max_num_sequences, int) + assert isinstance(engine_config.max_decode_steps, int) + assert isinstance(engine_config.min_decode_steps, int) + + # TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly + assert engine_config.max_num_batched_tokens == -1, \ + "`max_num_batched_tokens` is not supposed to be configured directly. \ + Use `max_num_sequences` and `max_input_len` instead." + assert engine_config.max_input_len > 0 + assert engine_config.max_num_sequences > 0 + engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len + + assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0) + assert engine_config.max_decode_steps > engine_config.min_decode_steps return engine_config diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index dce55c3121..8614675c82 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -65,7 +65,6 @@ class GenerationLoopWorker: max_num_batched_tokens: int max_decode_steps: int min_decode_steps: int - prompt_allocate_ratio: float queue_lock: Lock queue: Deque[RequestState] has_new_requests: Condition @@ -92,8 +91,6 @@ def __init__( self.min_decode_steps = min( self.max_decode_steps - 1, model_module.engine_config.min_decode_steps ) - self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio - assert self.prompt_allocate_ratio >= 1.0 self.queue_lock = Lock() self.queue = deque[RequestState]() @@ -113,9 +110,14 @@ def add(self, request_states: list[RequestState]): for request_state in request_states: if ( request_state.validation_err is not None - or request_state.prompt_len >= self.max_context_length + or request_state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens) + # Need to exclude requests which cannot fit into the kv_cache and can be processed + # at least max_decode_steps steps + or self.cache_manager.get_kv_cache_size() - request_state.prompt_len < self.max_decode_steps ): self.cancelled_requests.append(request_state) + if request_state.validation_err is None: + request_state.validation_err = "The prompt is too long for the given set of engine parameters." else: valid_states.append(request_state) @@ -194,24 +196,25 @@ def step(self) -> GenerationLoopWorkerOutput: self.stopped_requests.clear() - for state in self.cancelled_requests: - err = None - if state.validation_err: - err = state.validation_err + with self.queue_lock: + for state in self.cancelled_requests: + err = None + if state.validation_err: + err = state.validation_err - outputs.append( - SequenceGenerationOutput( - # TODO: support multi-sequence - id=SequenceId(state.request_id, 0), - new_tokens=[], - finish_reason=FinishReason.Cancelled, - error=err, + outputs.append( + SequenceGenerationOutput( + # TODO: support multi-sequence + id=SequenceId(state.request_id, 0), + new_tokens=[], + finish_reason=FinishReason.Cancelled, + error=err, + ) ) - ) - if state.request_id in self.current_batch: - self._remove_request_from_batch(state.request_id) + if state.request_id in self.current_batch: + self._remove_request_from_batch(state.request_id) - self.cancelled_requests.clear() + self.cancelled_requests.clear() self._adjust_batch() @@ -297,15 +300,24 @@ def _adjust_batch(self): state = self.queue[0] num_tokens = len(state.token_ids) num_new_batched_tokens += num_tokens + # This can happen when we are recovering from cache eviction and the sum of prompt + # and intermediate decode tokens is bigger than the biggest allowable batch size, + # self.max_num_batched_tokens. In such cases, we need to discard the recent decode + # tokens that cannot fit into a batch, and recompute them after we fill the cache + # entries for the older tokens. + if len(self.current_batch) == 0 and num_new_batched_tokens > self.max_num_batched_tokens: + state.token_ids = state.token_ids[:self.max_num_batched_tokens] + state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens if num_new_batched_tokens > self.max_num_batched_tokens > 0: LOG.debug( "Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s", num_new_batched_tokens, ) break + # Make sure to leave some free space in the KV cache after a request is added or batched if ( - self.cache_manager.get_free_space() - <= self.prompt_allocate_ratio * num_tokens + (self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1) + < self.max_decode_steps ): LOG.debug( "Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s", diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index 8793797edf..f7c8399282 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -41,7 +41,6 @@ class SynchronousInferenceEngine(InferenceEngine): max_num_batched_tokens: int max_decode_steps: int min_decode_steps: int - prompt_allocate_ratio: float queue_lock: Lock queue: Deque[RequestState] has_new_requests: Condition @@ -60,6 +59,7 @@ def __init__( assert self.model_artifact_config.max_context_length, "max_context_length must not be zero" self.max_context_length = self.model_artifact_config.max_context_length self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens + assert self.max_num_batched_tokens > 0, "max_num_batched_tokens must be positive" self.max_decode_steps = min( self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps, @@ -67,8 +67,6 @@ def __init__( self.min_decode_steps = min( self.max_decode_steps - 1, model_module.engine_config.min_decode_steps ) - self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio - assert self.prompt_allocate_ratio >= 1.0 self.queue_lock = Lock() self.queue = deque[RequestState]() @@ -98,8 +96,16 @@ def add(self, requests: list[Request]): state = self._get_new_request_state(req) new_request_states.append(state) - if state.prompt_len >= self.max_context_length: + if ( + state.validation_err is not None + or state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens) + # Need to exclude requests which cannot fit into the kv_cache and can be processed + # at least max_decode_steps steps + or self.cache_manager.get_kv_cache_size() - state.prompt_len < self.max_decode_steps + ): self.cancel(req.request_id) + if state.validation_err is None: + state.validation_err = "The prompt is too long for the given set of engine parameters." with self.queue_lock: self.queue.extend(new_request_states) @@ -279,15 +285,24 @@ def _adjust_batch(self): state = self.queue[0] num_tokens = len(state.token_ids) num_new_batched_tokens += num_tokens + # This can happen when we are recovering from cache eviction and the sum of prompt + # and intermediate decode tokens is bigger than the biggest allowable batch size, + # self.max_num_batched_tokens. In such cases, we need to discard the recent decode + # tokens that cannot fit into a batch, and recompute them after we fill the cache + # entries for the older tokens. + if not len(self.current_batch) and num_new_batched_tokens > self.max_num_batched_tokens: + state.token_ids = state.token_ids[:self.max_num_batched_tokens] + state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens if num_new_batched_tokens > self.max_num_batched_tokens > 0: logger.debug( "Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s", num_new_batched_tokens, ) break + # Make sure to leave some free space in the KV cache after a request is added or batched if ( - self.cache_manager.get_free_space() - <= self.prompt_allocate_ratio * num_tokens + (self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1) + < self.max_decode_steps ): logger.debug( "Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s", diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py new file mode 100644 index 0000000000..f81a06c515 --- /dev/null +++ b/serve/mlc_serve/model/dummy_model.py @@ -0,0 +1,138 @@ +from typing import Optional, Union + +from mlc_serve.engine import ( + ChatMessage, + DebugOptions, + FinishReason, + Request, + RequestId, + RequestOutput, + SamplingParams, + StoppingCriteria, + get_engine_config +) +from mlc_serve.model.base import ModelArtifactConfig +from mlc_serve.engine.model_module import ( + ConversationTemplate, + DecodeRequest, + KVCache, + KVCacheManager, + ModelModule, + PrefillRequest, + SequenceId, + TextGenerationResult, + TextGenerator, + Tokenizer, +) + +class DummyTokenizer: + @property + def eos_token_id(self): + return 2 + + def encode(self, text: str, **kwargs) -> list[int]: + return [1] * len(text.split()) + + def decode(self, tokens: list[int], **kwargs) -> str: + return "test " * len(tokens) + + +class DummyConversationTemplate: + def apply(self, messages: list[ChatMessage]) -> str: + return " ".join(m.content for m in messages if m.content is not None) + + +class DummyCache: + def __init__(self, max_cached_tokens: int): + self.max_cached_tokens = max_cached_tokens + self.cached_requests = dict[RequestId, int]() + + +class DummyCacheManager: + def __init__(self, max_cached_tokens: int): + self.cache = DummyCache(max_cached_tokens) + + def get_cache(self) -> KVCache: + return self.cache + + def allocate(self, request_id: RequestId, num_tokens: int) -> bool: + self.cache.cached_requests[request_id] = num_tokens + if self.get_free_space() < 0: + raise RuntimeError("Cache out of space") + + def extend(self, sequence_id: SequenceId, new_tokens: int) -> bool: + if sequence_id.sequence_index > 0: + raise RuntimeError("Multiple generated sequences not supported") + self.cache.cached_requests[sequence_id.request_id] += new_tokens + if self.get_free_space() < 0: + raise RuntimeError("Cache out of space") + + def free(self, sequence_id: SequenceId): + if sequence_id.sequence_index > 0: + raise RuntimeError("Multiple generated sequences not supported") + del self.cache.cached_requests[sequence_id.request_id] + + def get_kv_cache_size(self) -> int: + return self.cache.max_cached_tokens + + def get_free_space(self) -> int: + return self.cache.max_cached_tokens - sum(self.cache.cached_requests.values()) + + def get_max_new_tokens(self) -> int: + if not self.cache.cached_requests: + return self.get_kv_cache_size() + return self.get_free_space() // len(self.cache.cached_requests) + + +class DummyTextGenerator: + def generate( + self, + requests: list[Union[PrefillRequest, DecodeRequest]], + kv_cache: KVCache, + ) -> list[TextGenerationResult]: + result = [] + for req in requests: + if isinstance(req, DecodeRequest): + request_id = req.sequence_id.request_id + if req.sequence_id.sequence_index > 0: + raise RuntimeError("Multiple generated sequences not supported") + else: + request_id = req.request_id + + if len(req.token_ids) > kv_cache.cached_requests[request_id]: + raise RuntimeError(f"Cache out of space for request {req.request_id}") + result.append( + TextGenerationResult( + sequence_id=SequenceId( + request_id=request_id, + sequence_index=0, + ), + generated_tokens=[1], + error=None, + ) + ) + return result + + +class DummyModelModule: + def __init__(self, max_cached_tokens: int, max_input_len = 512, max_num_sequences = 8): + self.tokenizer = DummyTokenizer() + self.conversation_template = DummyConversationTemplate() + self.text_generator = DummyTextGenerator() + self.cache_manager = DummyCacheManager(max_cached_tokens) + self.model_artifact_config = ModelArtifactConfig._from_json({ + "max_context_length": 1024, + }) + self.engine_config = get_engine_config({ + "max_decode_steps": 2, + "min_decode_steps": 1, + "use_staging_engine" : False, + "max_input_len": max_input_len, + "max_num_sequences": max_num_sequences + }) + + +class DummyTokenizerModule: + def __init__(self): + self.tokenizer = DummyTokenizer() + self.conversation_template = DummyConversationTemplate() \ No newline at end of file diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py index d220de6143..3516fea02e 100644 --- a/serve/mlc_serve/run.py +++ b/serve/mlc_serve/run.py @@ -36,7 +36,6 @@ def parse_args(): args.add_argument("--max-input-len", type=int, default=512) args.add_argument("--min-decode-steps", type=int, default=12) args.add_argument("--max-decode-steps", type=int, default=16) - args.add_argument("--prompt-allocate-ratio", type=float, default=2.0) args.add_argument("--debug-logging", action="store_true") parsed = args.parse_args() return parsed @@ -64,7 +63,6 @@ def create_engine( "max_input_len": args.max_input_len, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, - "prompt_allocate_ratio": args.prompt_allocate_ratio }) if args.use_staging_engine: diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 4e8acc2d83..947656110c 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -20,7 +20,7 @@ from mlc_serve.logging_utils import configure_logging -def test(args: argparse.Namespace): +def _test(args: argparse.Namespace): # Examples. "--max-output-len" can be used to specify the number of output tokens. # # Profile the gpu memory usage, and use the maximum number of cache blocks possible: @@ -39,7 +39,6 @@ def test(args: argparse.Namespace): "max_input_len": args.max_input_len, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, - "prompt_allocate_ratio": args.prompt_allocate_ratio, } ) @@ -127,7 +126,6 @@ def test(args: argparse.Namespace): parser.add_argument("--max-input-len", type=int, default=512) parser.add_argument("--max-num-sequences", type=int, default=8) parser.add_argument("--max-output-len", type=int, default=20) - parser.add_argument("--prompt-allocate-ratio", type=float, default=2.0) parser.add_argument("--long-prompt", action="store_true") parser.add_argument("--use-random-sampling", action="store_true") parser.add_argument("--use-staging-engine", action="store_true") @@ -148,6 +146,7 @@ def test(args: argparse.Namespace): torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) + configure_logging(enable_json_logs=False, log_level="INFO") - test(args) + _test(args) diff --git a/serve/tests/unittest/test_engine_init.py b/serve/tests/unittest/test_engine_init.py index f4f493b224..a52328b646 100644 --- a/serve/tests/unittest/test_engine_init.py +++ b/serve/tests/unittest/test_engine_init.py @@ -6,7 +6,7 @@ from mlc_serve.model.paged_cache_model import PagedCacheModelModule -def test_insufficient_cache_blocks_fail(artifact_path): +def _test_insufficient_cache_blocks_fail(artifact_path): model_artifact_path = os.path.join(artifact_path, "codellama-13b-instruct-hf-q0f16") if not os.path.exists(os.path.join(model_artifact_path)): @@ -20,7 +20,6 @@ def try_init(max_num_seqs): "max_input_len": 16384, "min_decode_steps": 12, "max_decode_steps": 16, - "prompt_allocate_ratio": 2.0, } ) @@ -41,4 +40,4 @@ def try_init(max_num_seqs): parser.add_argument("--artifact-path", type=str, default="dist") args = parser.parse_args() - test_insufficient_cache_blocks_fail(args.artifact_path) + _test_insufficient_cache_blocks_fail(args.artifact_path) diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index af10398fc2..57f8bb360b 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -31,7 +31,7 @@ def create_engine( "use_staging_engine": use_staging_engine, "max_num_sequences": max_num_sequences, "max_input_len": max_input_len, - # Use defaults for "min_decode_steps", "max_decode_steps", "prompt_allocate_ratio" + # Use defaults for "min_decode_steps", "max_decode_steps" }) if use_staging_engine: @@ -66,7 +66,7 @@ def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos): debug_options = DebugOptions(ignore_eos = ignore_eos) ) -def test_max_tokens( +def _test_max_tokens( model_artifact_path, use_staging_engine, max_num_sequences=4, @@ -103,7 +103,7 @@ def test_max_tokens( engine.stop() -def test_max_context_length( +def _test_max_context_length( model_artifact_path, use_staging_engine, max_num_sequences=4, @@ -141,7 +141,7 @@ def test_max_context_length( engine.stop() -def test_ignore_eos( +def _test_ignore_eos( model_artifact_path, use_staging_engine, max_num_sequences=4, @@ -177,7 +177,7 @@ def test_ignore_eos( engine.stop() -def test_stop( +def _test_stop( model_artifact_path, use_staging_engine, max_num_sequences=4, @@ -227,11 +227,11 @@ def test_stop( args = parser.parse_args() model_artifact_path = os.path.join(args.artifact_path, args.local_id) - test_max_tokens(model_artifact_path, use_staging_engine=True) - test_max_tokens(model_artifact_path, use_staging_engine=False) - test_ignore_eos(model_artifact_path, use_staging_engine=True) - test_ignore_eos(model_artifact_path, use_staging_engine=False) - test_stop(model_artifact_path, use_staging_engine=False) - test_stop(model_artifact_path, use_staging_engine=True) - test_max_context_length(model_artifact_path, use_staging_engine=True) - test_max_context_length(model_artifact_path, use_staging_engine=False) + _test_max_tokens(model_artifact_path, use_staging_engine=True) + _test_max_tokens(model_artifact_path, use_staging_engine=False) + _test_ignore_eos(model_artifact_path, use_staging_engine=True) + _test_ignore_eos(model_artifact_path, use_staging_engine=False) + _test_stop(model_artifact_path, use_staging_engine=False) + _test_stop(model_artifact_path, use_staging_engine=True) + _test_max_context_length(model_artifact_path, use_staging_engine=True) + _test_max_context_length(model_artifact_path, use_staging_engine=False) diff --git a/serve/tests/unittest/test_staging_engine.py b/serve/tests/unittest/test_staging_engine.py new file mode 100644 index 0000000000..f44e46c24c --- /dev/null +++ b/serve/tests/unittest/test_staging_engine.py @@ -0,0 +1,357 @@ +from typing import Optional, Union + +from mlc_serve.engine import ( + ChatMessage, + FinishReason, + Request, + RequestId, + RequestOutput, + SamplingParams, + StoppingCriteria, + get_engine_config +) +from mlc_serve.model.base import ModelArtifactConfig +from mlc_serve.engine.model_module import ( + DecodeRequest, + KVCache, + PrefillRequest, + SequenceId, + TextGenerationResult, +) +from mlc_serve.engine.sync_engine import SynchronousInferenceEngine +from mlc_serve.engine.staging_engine import StagingInferenceEngine + +from mlc_serve.model.dummy_model import ( + DummyModelModule, + DummyTokenizerModule, +) + +def create_messages(prompt) -> list[ChatMessage]: + return [ChatMessage(role="user", content=prompt)] + +def get_output_for_request( + outputs: list[RequestOutput], request_id: RequestId +) -> Optional[RequestOutput]: + for o in outputs: + if o.request_id == request_id: + return o + return None + + +def test_single_request(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 30 + } + ) + engine.start() + + request_id = "1" + engine.add( + [ + Request( + request_id=request_id, + messages=create_messages("test prompt"), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=5), + ), + ] + ) + + step = engine.step() + # execute everything to free server and allow it to destroy + for i in range(0,100): + engine.step() + engine.stop() + + assert step.outputs[0].request_id == request_id + assert step.outputs[0].error is None + assert len(step.outputs) == 1 + + +def test_single_request_step_to_finish(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 30 + } + ) + engine.start() + + request_id = "1" + engine.add( + [ + Request( + request_id=request_id, + messages=create_messages("test prompt"), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=10), + ), + ] + ) + + steps = [engine.step() for _ in range(11)] + # execute everything to free server and allow it to destroy + for i in range(0,100): + engine.step() + engine.stop() + + assert steps[-1].outputs[0].request_id == request_id + assert steps[-1].outputs[0].sequences[0].finish_reason == FinishReason.Length + assert len(steps[-1].outputs) == 1 + + +def test_multiple_requests_wait_queue(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 20 + } + ) + engine.start() + + request_id_1 = "1" + request_id_2 = "2" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("test " * 11), # 11 tokens + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=2), + ), + ] + ) + + engine.add( + [ + Request( + request_id=request_id_2, + messages=create_messages("test " * 11), # 11 tokens + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=2), + ), + ] + ) + + steps = [engine.step() for _ in range(3)] + # execute everything to free server and allow it to destroy + for i in range(0,100): + engine.step() + engine.stop() + + assert len(steps[0].outputs) == 1 + assert steps[0].outputs[0].request_id == request_id_1 + + assert len(steps[1].outputs) == 1 + assert steps[1].outputs[0].request_id == request_id_1 + + assert len(steps[2].outputs) == 2 + assert ( + get_output_for_request(steps[2].outputs, request_id_1) + .sequences[0] + .finish_reason + == FinishReason.Length + ) + assert get_output_for_request(steps[2].outputs, request_id_2) is not None + + +def test_multiple_requests_preempt(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 30 + } + ) + engine.start() + + request_id_1 = "1" + request_id_2 = "2" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("test " * 10), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=7), + ), + ] + ) + + engine.add( + [ + Request( + request_id=request_id_2, + messages=create_messages("test " * 10), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=7), + ), + ] + ) + + steps = [engine.step() for _ in range(20)] + # execute everything to free server and allow it to destroy + for i in range(0,100): + engine.step() + engine.stop() + + # Due to asynchronious nature of request submission and processing, we cannot + # track exactly on each step certain request is processed + # but we can catch a pattern that it is two requests processed, then 1 and then again 2 + stage = 0 + + for s in steps: + if stage == 0 and len(s.outputs) == 2: + stage = 1 + if stage == 1 and len(s.outputs) == 1: + stage = 2 + if stage == 2 and len(s.outputs) == 2: + stage = 3 + if stage == 3 and len(s.outputs) == 1: + stage = 4 + assert s.outputs[0].is_finished + if stage == 4 and len(s.outputs) > 1: + stage = 5 + + assert stage == 4 + + +# Test to verify if evicted request from active batch which in intermediate +# state exceeding the max_num_batched_tokens can be processed successfully and will +# not hang the server in infinite attempt to return it back to the active loop +def test_cache_evict_hang_staging(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 40, + "max_input_len": 10, + "max_num_sequences": 2 + } + ) + engine.start() + + request_id_1 = "1" + request_id_2 = "2" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("A " * 10), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=20), + ), + ] + ) + + engine.add( + [ + Request( + request_id=request_id_2, + messages=create_messages("A " * 10), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=20), + ), + ] + ) + + steps = [engine.step() for _ in range(40)] + engine.stop() + + finished = {} + empty_step = 0 + for s in steps: + for o in s.outputs: + if o.is_finished: + finished[o.request_id] = True + if not len(s.outputs): + empty_step += 1 + + assert len(finished) == 2 + assert empty_step < 10 + + +# Test to verify if new comming request with big prompt can be put into inference +# and does not have issues with cache size limits verification +def test_big_prompt_fit_to_cache_staging(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 40, + "max_input_len": 30, + "max_num_sequences": 1 + } + ) + engine.start() + + request_id_1 = "1" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("A " * 30), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=5), + ), + ] + ) + + steps = [engine.step() for _ in range(40)] + engine.stop() + + finished = {} + return_token_step = 0 + for s in steps: + for o in s.outputs: + if o.is_finished: + finished[o.request_id] = True + if len(s.outputs): + return_token_step += 1 + + assert len(finished) == 1 + assert return_token_step >= 5 + + +# Test to verify if new comming request with big prompt is handled properly +def test_big_prompt_not_fit_to_cache(): + engine = StagingInferenceEngine( + tokenizer_module=DummyTokenizerModule(), + model_module_loader=DummyModelModule, + model_module_loader_kwargs = { + "max_cached_tokens": 29, + "max_input_len": 30, + "max_num_sequences": 1 + } + ) + engine.start() + + request_id_1 = "1" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("A " * 30), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=2), + ), + ] + ) + + steps = [engine.step() for _ in range(5)] + engine.stop() + + assert len(steps[0].outputs) == 1 + assert steps[0].outputs[0].is_finished + assert steps[0].outputs[0].error + assert len(steps[1].outputs) == 0 diff --git a/serve/tests/unittest/test_synchronous_inference_engine.py b/serve/tests/unittest/test_sync_engine.py similarity index 51% rename from serve/tests/unittest/test_synchronous_inference_engine.py rename to serve/tests/unittest/test_sync_engine.py index 3b7810cbbe..a63e2e791f 100644 --- a/serve/tests/unittest/test_synchronous_inference_engine.py +++ b/serve/tests/unittest/test_sync_engine.py @@ -2,7 +2,6 @@ from mlc_serve.engine import ( ChatMessage, - DebugOptions, FinishReason, Request, RequestId, @@ -13,123 +12,20 @@ ) from mlc_serve.model.base import ModelArtifactConfig from mlc_serve.engine.model_module import ( - ConversationTemplate, DecodeRequest, KVCache, - KVCacheManager, - ModelModule, PrefillRequest, SequenceId, TextGenerationResult, - TextGenerator, - Tokenizer, ) -from mlc_serve.engine.sync_engine import SynchronousInferenceEngine +from mlc_serve.engine.sync_engine import SynchronousInferenceEngine +from mlc_serve.engine.staging_engine import StagingInferenceEngine -class DummyTokenizer: - @property - def eos_token_id(self): - return 2 - - def encode(self, text: str, **kwargs) -> list[int]: - return [1] * len(text.split()) - - def decode(self, tokens: list[int], **kwargs) -> str: - return "test " * len(tokens) - - -class DummyConversationTemplate: - def apply(self, messages: list[ChatMessage]) -> str: - return " ".join(m.content for m in messages if m.content is not None) - - -class DummyCache: - def __init__(self, max_cached_tokens: int): - self.max_cached_tokens = max_cached_tokens - self.cached_requests = dict[RequestId, int]() - - -class DummyCacheManager: - def __init__(self, max_cached_tokens: int): - self.cache = DummyCache(max_cached_tokens) - - def get_cache(self) -> KVCache: - return self.cache - - def allocate(self, request_id: RequestId, num_tokens: int) -> bool: - self.cache.cached_requests[request_id] = num_tokens - if self.get_free_space() < 0: - raise RuntimeError("Cache out of space") - - def extend(self, sequence_id: SequenceId, new_tokens: int) -> bool: - if sequence_id.sequence_index > 0: - raise RuntimeError("Multiple generated sequences not supported") - self.cache.cached_requests[sequence_id.request_id] += new_tokens - if self.get_free_space() < 0: - raise RuntimeError("Cache out of space") - - def free(self, sequence_id: SequenceId): - if sequence_id.sequence_index > 0: - raise RuntimeError("Multiple generated sequences not supported") - del self.cache.cached_requests[sequence_id.request_id] - - def get_kv_cache_size(self) -> int: - return self.cache.max_cached_tokens - - def get_free_space(self) -> int: - return self.cache.max_cached_tokens - sum(self.cache.cached_requests.values()) - - def get_max_new_tokens(self) -> int: - if not self.cache.cached_requests: - return self.get_kv_cache_size() - return self.get_free_space() // len(self.cache.cached_requests) - - -class DummyTextGenerator: - def generate( - self, - requests: list[Union[PrefillRequest, DecodeRequest]], - kv_cache: KVCache, - ) -> list[TextGenerationResult]: - result = [] - for req in requests: - if isinstance(req, DecodeRequest): - request_id = req.sequence_id.request_id - if req.sequence_id.sequence_index > 0: - raise RuntimeError("Multiple generated sequences not supported") - else: - request_id = req.request_id - - if len(req.token_ids) > kv_cache.cached_requests[request_id]: - raise RuntimeError(f"Cache out of space for request {req.request_id}") - result.append( - TextGenerationResult( - sequence_id=SequenceId( - request_id=request_id, - sequence_index=0, - ), - generated_tokens=[1], - error=None, - ) - ) - return result - - -class DummaryModelModule: - def __init__(self, max_cached_tokens: int): - self.tokenizer = DummyTokenizer() - self.conversation_template = DummyConversationTemplate() - self.text_generator = DummyTextGenerator() - self.cache_manager = DummyCacheManager(max_cached_tokens) - self.model_artifact_config = ModelArtifactConfig._from_json({ - "max_context_length": 1024, - }) - self.engine_config = get_engine_config({ - "max_decode_steps": 0, - "min_decode_steps": 0, - "prompt_allocate_ratio": 1.0 - }, enable_check = False) +from mlc_serve.model.dummy_model import ( + DummyModelModule, + DummyTokenizerModule, +) def create_messages(prompt) -> list[ChatMessage]: @@ -146,7 +42,7 @@ def get_output_for_request( def test_single_request(): - engine = SynchronousInferenceEngine(DummaryModelModule(30)) + engine = SynchronousInferenceEngine(DummyModelModule(30)) request_id = "1" engine.add( [ @@ -167,7 +63,7 @@ def test_single_request(): def test_single_request_step_to_finish(): - engine = SynchronousInferenceEngine(DummaryModelModule(30)) + engine = SynchronousInferenceEngine(DummyModelModule(30)) request_id = "1" engine.add( @@ -189,7 +85,7 @@ def test_single_request_step_to_finish(): def test_multiple_requests_wait_queue(): - engine = SynchronousInferenceEngine(DummaryModelModule(20)) + engine = SynchronousInferenceEngine(DummyModelModule(20)) request_id_1 = "1" request_id_2 = "2" @@ -235,7 +131,7 @@ def test_multiple_requests_wait_queue(): def test_multiple_requests_preempt(): - engine = SynchronousInferenceEngine(DummaryModelModule(30)) + engine = SynchronousInferenceEngine(DummyModelModule(30)) request_id_1 = "1" request_id_2 = "2" @@ -272,3 +168,112 @@ def test_multiple_requests_preempt(): assert len(steps[7].outputs) == 2 assert get_output_for_request(steps[7].outputs, finished_request_id).is_finished + + +# Test to verify if evicted request from active batch which in intermediate +# state exceeding the max_num_batched_tokens can be processed successfully and will +# not hang the server in infinite attempt to return it back to the active loop +def test_cache_evict_hang(): + engine = SynchronousInferenceEngine(DummyModelModule(40, 10, 2)) + + request_id_1 = "1" + request_id_2 = "2" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("A " * 10), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=20), + ), + ] + ) + + engine.add( + [ + Request( + request_id=request_id_2, + messages=create_messages("A " * 10), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=20), + ), + ] + ) + + steps = [engine.step() for _ in range(40)] + + finished = {} + empty_step = 0 + for s in steps: + for o in s.outputs: + if o.is_finished: + finished[o.request_id] = True + if not len(s.outputs): + empty_step += 1 + + assert len(finished) == 2 + assert empty_step < 10 + + +# Test to verify if new comming request with big prompt can be put into inference +# and does not have issues with cache size limits verification +def test_big_prompt_fit_to_cache(): + engine = SynchronousInferenceEngine(DummyModelModule(40, 30, 1)) + + request_id_1 = "1" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("A " * 30), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=5), + ), + ] + ) + + steps = [engine.step() for _ in range(40)] + + finished = {} + return_token_step = 0 + for s in steps: + for o in s.outputs: + if o.is_finished: + finished[o.request_id] = True + if len(s.outputs): + return_token_step += 1 + + assert len(finished) == 1 + assert return_token_step >= 5 + + +# Test to verify if new comming request with big prompt is handled properly +def test_big_prompt_not_fit_to_cache(): + engine = SynchronousInferenceEngine(DummyModelModule(29, 30, 1)) + + request_id_1 = "1" + + engine.add( + [ + Request( + request_id=request_id_1, + messages=create_messages("A " * 30), + sampling_params=SamplingParams(temperature=1), + stopping_criteria=StoppingCriteria(max_tokens=2), + ), + ] + ) + + steps = [engine.step() for _ in range(5)] + + assert len(steps[0].outputs) == 1 + assert steps[0].outputs[0].is_finished + # TODO(amalyshe): the behaviour of sync and staged engines are not consistent + # Staging engine handles this situation better, it returns error and no sequences + assert steps[0].outputs[0].sequences[0].finish_reason == FinishReason.Cancelled + # TODO(amalyshe:) + # There must be error, but currently error is lost in the engine, need to fix + # assert steps[0].outputs[0].error + assert len(steps[1].outputs) == 0 \ No newline at end of file