Commit: Fix the server hang in case of long prompts/requests in certain server configuration (#89)

* Fix the server hang in case of long prompts/requests in certain server configuration

There is a discrepancy between the limits imposed by the model, the user request, and the KV cache.
To avoid an infinite loop, the engine needs to discard long requests that exceed the server limits
and/or handle long requests more flexibly (a short sketch of the resulting admission check follows
this list of changes).

* remove unnecessary verification in add request

* Refactor validation conditions during request add

* Make one more condition simplification

* Remove configuration verification bypass and fix tests

* Fix conditions to discard requests from inference

* Add test on sync engine hang for two use cases

* Add staged engine tests

* Move DummaryModel from tests to mlc_serve

* Rename non-pytest test functions to exclude them from regular pytest testing

* address PR comments

* Clean up prompt_allocate_ratio / prompt-allocate-ratio

* Rename unittest files

* Beautify comment

* Add test for request exceed cache limit, fix comment

* fix formatting

* Add comment

* Comments beautification

* Change Dummary naming to Dummy
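
A minimal sketch of the admission rule this change introduces, assuming the attribute names used in the
engine code below (max_context_length, max_num_batched_tokens, max_decode_steps, and the KV cache size
reported by the cache manager); the real checks live in staging_engine_worker.add() and sync_engine.add():

def should_reject(prompt_len: int,
                  max_context_length: int,
                  max_num_batched_tokens: int,
                  kv_cache_size: int,
                  max_decode_steps: int) -> bool:
    """Return True if a request can never be served and must be cancelled up front."""
    # The prompt has to fit both the model context window and the largest
    # batch the engine can prefill in one step.
    if prompt_len > min(max_context_length, max_num_batched_tokens):
        return True
    # After prefill there must be enough KV cache left to run at least
    # max_decode_steps decode steps; otherwise the request can never finish
    # and the serving loop hangs.
    if kv_cache_size - prompt_len < max_decode_steps:
        return True
    return False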
elvin-n authored Nov 30, 2023
1 parent 352e5b6 commit 927f30e
Showing 11 changed files with 704 additions and 187 deletions.
2 changes: 0 additions & 2 deletions serve/benchmarks/benchmark_throughput.py
@@ -104,7 +104,6 @@ def create_engine_and_tokenizer_module(
"max_input_len": args.max_input_len,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"prompt_allocate_ratio": args.prompt_allocate_ratio
})

if args.use_staging_engine:
@@ -182,7 +181,6 @@ def main(args: argparse.Namespace):
parser.add_argument("--max-input-len", type=int, default=512)
parser.add_argument("--min-decode-steps", type=int, default=32)
parser.add_argument("--max-decode-steps", type=int, default=56)
parser.add_argument("--prompt-allocate-ratio", type=float, default=2.0)
parser.add_argument(
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
40 changes: 18 additions & 22 deletions serve/mlc_serve/engine/base.py
@@ -27,7 +27,6 @@ class MLCServeEngineConfig:
max_num_batched_tokens: int = -1
min_decode_steps: int = 32
max_decode_steps: int = 48
prompt_allocate_ratio: float = 2.0

@classmethod
def _from_json(config_cls, json_obj: Dict[Any, Any]):
@@ -39,30 +38,27 @@ def _from_json(config_cls, json_obj: Dict[Any, Any]):
}
)

def get_engine_config(dict_config, enable_check = True):
def get_engine_config(dict_config):
engine_config = MLCServeEngineConfig._from_json(dict_config)
# Checks to make sure engine configs are set correctly
# since engine config is critical to the performance
if enable_check:
assert isinstance(engine_config.use_staging_engine, bool)
assert isinstance(engine_config.max_num_batched_tokens, int)
assert isinstance(engine_config.max_input_len, int)
assert isinstance(engine_config.max_num_sequences, int)
assert isinstance(engine_config.max_decode_steps, int)
assert isinstance(engine_config.min_decode_steps, int)
assert isinstance(engine_config.prompt_allocate_ratio, float)

# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
assert engine_config.max_num_batched_tokens == -1, \
"`max_num_batched_tokens` is not supposed to be configured directly. \
Use `max_num_sequences` and `max_input_len` instead."
assert engine_config.max_input_len > 0
assert engine_config.max_num_sequences > 0
engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps
assert engine_config.prompt_allocate_ratio > 0
assert isinstance(engine_config.use_staging_engine, bool)
assert isinstance(engine_config.max_num_batched_tokens, int)
assert isinstance(engine_config.max_input_len, int)
assert isinstance(engine_config.max_num_sequences, int)
assert isinstance(engine_config.max_decode_steps, int)
assert isinstance(engine_config.min_decode_steps, int)

# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
assert engine_config.max_num_batched_tokens == -1, \
"`max_num_batched_tokens` is not supposed to be configured directly. \
Use `max_num_sequences` and `max_input_len` instead."
assert engine_config.max_input_len > 0
assert engine_config.max_num_sequences > 0
engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps

return engine_config

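
For reference, a small usage sketch of the validated configuration path, using the same field names as
above (the concrete values are illustrative only; get_engine_config is imported the same way the new
dummy_model.py below imports it):

from mlc_serve.engine import get_engine_config

engine_config = get_engine_config({
    "use_staging_engine": False,
    "max_num_sequences": 8,
    "max_input_len": 512,
    "min_decode_steps": 32,
    "max_decode_steps": 48,
})
# max_num_batched_tokens is derived rather than configured directly:
assert engine_config.max_num_batched_tokens == 8 * 512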
54 changes: 33 additions & 21 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -65,7 +65,6 @@ class GenerationLoopWorker:
max_num_batched_tokens: int
max_decode_steps: int
min_decode_steps: int
prompt_allocate_ratio: float
queue_lock: Lock
queue: Deque[RequestState]
has_new_requests: Condition
@@ -92,8 +91,6 @@ def __init__(
self.min_decode_steps = min(
self.max_decode_steps - 1, model_module.engine_config.min_decode_steps
)
self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio
assert self.prompt_allocate_ratio >= 1.0

self.queue_lock = Lock()
self.queue = deque[RequestState]()
@@ -113,9 +110,14 @@ def add(self, request_states: list[RequestState]):
for request_state in request_states:
if (
request_state.validation_err is not None
or request_state.prompt_len >= self.max_context_length
or request_state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
# Need to exclude requests whose prompt does not leave enough room in the kv_cache
# to be processed for at least max_decode_steps decode steps
or self.cache_manager.get_kv_cache_size() - request_state.prompt_len < self.max_decode_steps
):
self.cancelled_requests.append(request_state)
if request_state.validation_err is None:
request_state.validation_err = "The prompt is too long for the given set of engine parameters."
else:
valid_states.append(request_state)

@@ -194,24 +196,25 @@ def step(self) -> GenerationLoopWorkerOutput:

self.stopped_requests.clear()

for state in self.cancelled_requests:
err = None
if state.validation_err:
err = state.validation_err
with self.queue_lock:
for state in self.cancelled_requests:
err = None
if state.validation_err:
err = state.validation_err

outputs.append(
SequenceGenerationOutput(
# TODO: support multi-sequence
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error=err,
outputs.append(
SequenceGenerationOutput(
# TODO: support multi-sequence
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error=err,
)
)
)
if state.request_id in self.current_batch:
self._remove_request_from_batch(state.request_id)
if state.request_id in self.current_batch:
self._remove_request_from_batch(state.request_id)

self.cancelled_requests.clear()
self.cancelled_requests.clear()

self._adjust_batch()

@@ -297,15 +300,24 @@ def _adjust_batch(self):
state = self.queue[0]
num_tokens = len(state.token_ids)
num_new_batched_tokens += num_tokens
# This can happen when we are recovering from cache eviction and the sum of prompt
# and intermediate decode tokens is bigger than the biggest allowable batch size,
# self.max_num_batched_tokens. In such cases, we need to discard the recent decode
# tokens that cannot fit into a batch, and recompute them after we fill the cache
# entries for the older tokens.
if len(self.current_batch) == 0 and num_new_batched_tokens > self.max_num_batched_tokens:
state.token_ids = state.token_ids[:self.max_num_batched_tokens]
state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens
if num_new_batched_tokens > self.max_num_batched_tokens > 0:
LOG.debug(
"Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
num_new_batched_tokens,
)
break
# Make sure to leave some free space in the KV cache after a request is added or batched
if (
self.cache_manager.get_free_space()
<= self.prompt_allocate_ratio * num_tokens
(self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
< self.max_decode_steps
):
LOG.debug(
"Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s",
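
The condition removed above (get_free_space() <= prompt_allocate_ratio * num_tokens) reserved cache
space proportionally to the prompt length; the new condition instead guarantees every sequence in the
batch at least max_decode_steps of per-sequence KV cache headroom. A small numeric illustration of the
new check, with made-up values:

# Illustrative numbers only; the real values come from the cache manager and
# the engine config in _adjust_batch above.
free_space = 1000        # free KV cache slots before adding the candidate request
num_tokens = 700         # tokens of the request at the head of the queue
current_batch = 1        # sequences already in the running batch
max_decode_steps = 48

headroom = (free_space - num_tokens) / (current_batch + 1)
# headroom == 150.0 >= 48, so the request is admitted. With a 950-token candidate the
# headroom would be 25.0 < 48, the loop would stop growing the batch, and the request
# would stay queued until enough cache space is freed.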
27 changes: 21 additions & 6 deletions serve/mlc_serve/engine/sync_engine.py
@@ -41,7 +41,6 @@ class SynchronousInferenceEngine(InferenceEngine):
max_num_batched_tokens: int
max_decode_steps: int
min_decode_steps: int
prompt_allocate_ratio: float
queue_lock: Lock
queue: Deque[RequestState]
has_new_requests: Condition
@@ -60,15 +59,14 @@ def __init__(
assert self.model_artifact_config.max_context_length, "max_context_length must not be zero"
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
assert self.max_num_batched_tokens > 0, "max_num_batched_tokens must be positive"
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
)
self.min_decode_steps = min(
self.max_decode_steps - 1, model_module.engine_config.min_decode_steps
)
self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio
assert self.prompt_allocate_ratio >= 1.0

self.queue_lock = Lock()
self.queue = deque[RequestState]()
@@ -98,8 +96,16 @@ def add(self, requests: list[Request]):
state = self._get_new_request_state(req)
new_request_states.append(state)

if state.prompt_len >= self.max_context_length:
if (
state.validation_err is not None
or state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
# Need to exclude requests whose prompt does not leave enough room in the kv_cache
# to be processed for at least max_decode_steps decode steps
or self.cache_manager.get_kv_cache_size() - state.prompt_len < self.max_decode_steps
):
self.cancel(req.request_id)
if state.validation_err is None:
state.validation_err = "The prompt is too long for the given set of engine parameters."

with self.queue_lock:
self.queue.extend(new_request_states)
@@ -279,15 +285,24 @@ def _adjust_batch(self):
state = self.queue[0]
num_tokens = len(state.token_ids)
num_new_batched_tokens += num_tokens
# This can happen when we are recovering from cache eviction and the sum of prompt
# and intermediate decode tokens is bigger than the biggest allowable batch size,
# self.max_num_batched_tokens. In such cases, we need to discard the recent decode
# tokens that cannot fit into a batch, and recompute them after we fill the cache
# entries for the older tokens.
if not len(self.current_batch) and num_new_batched_tokens > self.max_num_batched_tokens:
state.token_ids = state.token_ids[:self.max_num_batched_tokens]
state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens
if num_new_batched_tokens > self.max_num_batched_tokens > 0:
logger.debug(
"Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
num_new_batched_tokens,
)
break
# Make sure to leave some free space in the KV cache after a request is added or batched
if (
self.cache_manager.get_free_space()
<= self.prompt_allocate_ratio * num_tokens
(self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
< self.max_decode_steps
):
logger.debug(
"Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s",
138 changes: 138 additions & 0 deletions serve/mlc_serve/model/dummy_model.py
@@ -0,0 +1,138 @@
from typing import Optional, Union

from mlc_serve.engine import (
ChatMessage,
DebugOptions,
FinishReason,
Request,
RequestId,
RequestOutput,
SamplingParams,
StoppingCriteria,
get_engine_config
)
from mlc_serve.model.base import ModelArtifactConfig
from mlc_serve.engine.model_module import (
ConversationTemplate,
DecodeRequest,
KVCache,
KVCacheManager,
ModelModule,
PrefillRequest,
SequenceId,
TextGenerationResult,
TextGenerator,
Tokenizer,
)

class DummyTokenizer:
@property
def eos_token_id(self):
return 2

def encode(self, text: str, **kwargs) -> list[int]:
return [1] * len(text.split())

def decode(self, tokens: list[int], **kwargs) -> str:
return "test " * len(tokens)


class DummyConversationTemplate:
def apply(self, messages: list[ChatMessage]) -> str:
return " ".join(m.content for m in messages if m.content is not None)


class DummyCache:
def __init__(self, max_cached_tokens: int):
self.max_cached_tokens = max_cached_tokens
self.cached_requests = dict[RequestId, int]()


class DummyCacheManager:
def __init__(self, max_cached_tokens: int):
self.cache = DummyCache(max_cached_tokens)

def get_cache(self) -> KVCache:
return self.cache

def allocate(self, request_id: RequestId, num_tokens: int) -> bool:
self.cache.cached_requests[request_id] = num_tokens
if self.get_free_space() < 0:
raise RuntimeError("Cache out of space")

def extend(self, sequence_id: SequenceId, new_tokens: int) -> bool:
if sequence_id.sequence_index > 0:
raise RuntimeError("Multiple generated sequences not supported")
self.cache.cached_requests[sequence_id.request_id] += new_tokens
if self.get_free_space() < 0:
raise RuntimeError("Cache out of space")

def free(self, sequence_id: SequenceId):
if sequence_id.sequence_index > 0:
raise RuntimeError("Multiple generated sequences not supported")
del self.cache.cached_requests[sequence_id.request_id]

def get_kv_cache_size(self) -> int:
return self.cache.max_cached_tokens

def get_free_space(self) -> int:
return self.cache.max_cached_tokens - sum(self.cache.cached_requests.values())

def get_max_new_tokens(self) -> int:
if not self.cache.cached_requests:
return self.get_kv_cache_size()
return self.get_free_space() // len(self.cache.cached_requests)


class DummyTextGenerator:
def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest]],
kv_cache: KVCache,
) -> list[TextGenerationResult]:
result = []
for req in requests:
if isinstance(req, DecodeRequest):
request_id = req.sequence_id.request_id
if req.sequence_id.sequence_index > 0:
raise RuntimeError("Multiple generated sequences not supported")
else:
request_id = req.request_id

if len(req.token_ids) > kv_cache.cached_requests[request_id]:
raise RuntimeError(f"Cache out of space for request {req.request_id}")
result.append(
TextGenerationResult(
sequence_id=SequenceId(
request_id=request_id,
sequence_index=0,
),
generated_tokens=[1],
error=None,
)
)
return result


class DummyModelModule:
def __init__(self, max_cached_tokens: int, max_input_len = 512, max_num_sequences = 8):
self.tokenizer = DummyTokenizer()
self.conversation_template = DummyConversationTemplate()
self.text_generator = DummyTextGenerator()
self.cache_manager = DummyCacheManager(max_cached_tokens)
self.model_artifact_config = ModelArtifactConfig._from_json({
"max_context_length": 1024,
})
self.engine_config = get_engine_config({
"max_decode_steps": 2,
"min_decode_steps": 1,
"use_staging_engine" : False,
"max_input_len": max_input_len,
"max_num_sequences": max_num_sequences
})


class DummyTokenizerModule:
def __init__(self):
self.tokenizer = DummyTokenizer()
self.conversation_template = DummyConversationTemplate()
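
A quick, self-contained check of the dummy cache accounting defined above (the request id "r0" is a
hypothetical value, and RequestId is assumed to accept a plain string here):

from mlc_serve.engine.model_module import SequenceId
from mlc_serve.model.dummy_model import DummyCacheManager

manager = DummyCacheManager(max_cached_tokens=30)
manager.allocate("r0", num_tokens=10)       # prefill reserves 10 slots
manager.extend(SequenceId("r0", 0), 5)      # a decode step adds 5 more
assert manager.get_free_space() == 15
assert manager.get_max_new_tokens() == 15   # 15 free slots, one cached request
manager.free(SequenceId("r0", 0))
assert manager.get_free_space() == 30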
2 changes: 0 additions & 2 deletions serve/mlc_serve/run.py
@@ -36,7 +36,6 @@ def parse_args():
args.add_argument("--max-input-len", type=int, default=512)
args.add_argument("--min-decode-steps", type=int, default=12)
args.add_argument("--max-decode-steps", type=int, default=16)
args.add_argument("--prompt-allocate-ratio", type=float, default=2.0)
args.add_argument("--debug-logging", action="store_true")
parsed = args.parse_args()
return parsed
@@ -64,7 +63,6 @@ def create_engine(
"max_input_len": args.max_input_len,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"prompt_allocate_ratio": args.prompt_allocate_ratio
})

if args.use_staging_engine: