Misc clean up, doc improvement (#98)
doc improvement, minor clean up
masahi authored Dec 6, 2023
1 parent 21ee0bb commit f55e6f6
Showing 3 changed files with 35 additions and 24 deletions.
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/staging_engine.py
@@ -83,7 +83,7 @@ def start(self):
         LOG.info("StagingInferenceEngine.start")
         try:
             self.worker_process.start()
-            if not self.ready_event.wait(timeout=90):
+            if not self.ready_event.wait(timeout=120):
                 raise RuntimeError(
                     "StagingInferenceEngine worker is not ready before timeout."
                 )
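For reference, the single change above raises the worker readiness timeout from 90 to 120 seconds, giving slow model loads more time before the engine gives up. Below is a minimal sketch of the general pattern, assuming a multiprocessing worker that signals readiness through an Event once startup finishes; the names are illustrative, not the actual mlc_serve classes.

import multiprocessing as mp


def _worker_main(ready_event):
    # ... heavy startup work such as loading model weights would happen here ...
    ready_event.set()  # signal the parent that the worker is ready
    # ... then enter the request-processing loop ...


if __name__ == "__main__":
    ready_event = mp.Event()
    worker = mp.Process(target=_worker_main, args=(ready_event,))
    worker.start()
    # A larger timeout tolerates slower startups at the cost of reporting failures later.
    if not ready_event.wait(timeout=120):
        worker.terminate()
        raise RuntimeError("worker is not ready before timeout")
    worker.join()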
48 changes: 29 additions & 19 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -120,19 +120,23 @@ def add(self, request_states: list[RequestState]):
         # States which have been invalidated should never be added, directly
         # cancel them instead.
         valid_states = []
-        for request_state in request_states:
+        kv_cache_size = self.cache_manager.get_kv_cache_size()
+        max_prompt_len = min(self.max_context_length, self.max_num_batched_tokens)
+        for state in request_states:
             if (
-                request_state.validation_err is not None
-                or request_state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
-                # Need to exclude requests which cannot fit into the kv_cache and can be processed
-                # at least max_decode_steps steps
-                or self.cache_manager.get_kv_cache_size() - request_state.prompt_len < self.max_decode_steps
+                state.validation_err is not None
+                or state.prompt_len > max_prompt_len
+                # We make sure that the KV cache will have enough free space for this request to proceed
+                # decoding for at least self.max_decode_steps steps.
+                or (kv_cache_size - state.prompt_len) < self.max_decode_steps
             ):
-                self.cancelled_requests.append(request_state)
-                if request_state.validation_err is None:
-                    request_state.validation_err = ValidationError("The prompt is too long for the given set of engine parameters.")
+                self.cancelled_requests.append(state)
+                if state.validation_err is None:
+                    state.validation_err = ValidationError(
+                        "The prompt is too long for the given set of engine parameters."
+                    )
             else:
-                valid_states.append(request_state)
+                valid_states.append(state)
 
         self.queue.extend(valid_states)
         self.has_new_requests.notify_all()
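For reference, the rewritten condition admits a request only when it carries no validation error, its prompt fits within min(max_context_length, max_num_batched_tokens), and the KV cache still has room for the prompt plus at least max_decode_steps decode tokens; everything else is cancelled with a ValidationError. A standalone sketch of that admission rule, using hypothetical free-function arguments in place of the engine attributes:

def can_admit(prompt_len, max_context_length, max_num_batched_tokens,
              kv_cache_size, max_decode_steps):
    # The prompt must fit both the model context window and the per-step token budget.
    if prompt_len > min(max_context_length, max_num_batched_tokens):
        return False
    # The KV cache must keep enough slots to decode at least max_decode_steps tokens.
    return kv_cache_size - prompt_len >= max_decode_steps


# A 4000-token prompt against a 4096-slot cache leaves only 96 free slots,
# so it is rejected when max_decode_steps is 256.
assert not can_admit(4000, 4096, 4096, 4096, 256)
assert can_admit(1024, 4096, 4096, 4096, 256)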
@@ -211,6 +215,7 @@ def step(self) -> GenerationLoopWorkerOutput:
         self.stopped_requests.clear()
 
         with self.queue_lock:
+            # Hold the lock here since self.cancelled_requests is modified in add(...) as well.
             for state in self.cancelled_requests:
                 err = None
                 if state.validation_err:
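For reference, the added comment records that self.cancelled_requests can be mutated by add() concurrently with the generation loop in step(), so both only touch it while holding queue_lock. A generic sketch of that locking pattern, with hypothetical names:

import threading


class CancelledRequestTracker:
    """Hypothetical illustration of the shared-list-plus-lock pattern."""

    def __init__(self):
        self.queue_lock = threading.Lock()
        self.cancelled_requests = []

    def add(self, state):
        # May be called concurrently with the generation loop, hence the lock.
        with self.queue_lock:
            self.cancelled_requests.append(state)

    def step(self):
        # Drain the list under the same lock so add() cannot interleave mid-iteration.
        with self.queue_lock:
            drained = list(self.cancelled_requests)
            self.cancelled_requests.clear()
        return drained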
@@ -328,20 +333,25 @@ def _adjust_batch(self):
             # self.max_num_batched_tokens. In such cases, we need to discard the recent decode
             # tokens that cannot fit into a batch, and recompute them after we fill the cache
             # entries for the older tokens.
-            if len(self.current_batch) == 0 and num_new_batched_tokens > self.max_num_batched_tokens:
-                state.token_ids = state.token_ids[:self.max_num_batched_tokens]
-                state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens
-            if num_new_batched_tokens > self.max_num_batched_tokens > 0:
+            if (
+                len(self.current_batch) == 0
+                and num_new_batched_tokens > self.max_num_batched_tokens
+            ):
+                state.token_ids = state.token_ids[: self.max_num_batched_tokens]
+                state.next_start_position = (
+                    num_new_batched_tokens
+                ) = num_tokens = self.max_num_batched_tokens
+            if num_new_batched_tokens > self.max_num_batched_tokens:
                 LOG.debug(
                     "Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
                     num_new_batched_tokens,
                 )
                 break
-            # Make sure to leave some free space in the KV cache after a request is added or batched
-            if (
-                (self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
-                < self.max_decode_steps
-            ):
+            # We make sure that the KV cache will have enough free space for all sequences in the batch
+            # to proceed decoding for at least self.max_decode_steps steps.
+            if (self.cache_manager.get_free_space() - num_tokens) / (
+                len(self.current_batch) + 1
+            ) < self.max_decode_steps:
                 LOG.debug(
                     "Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s",
                     self.cache_manager.get_free_space(),
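For reference, _adjust_batch stops adding requests when either the batched token count would exceed max_num_batched_tokens or the average free KV cache per sequence would fall below max_decode_steps. A simplified sketch of those two stopping conditions as a hypothetical free function (the real code operates on self.current_batch and the cache manager):

def should_stop_growing(num_new_batched_tokens, max_num_batched_tokens,
                        free_kv_space, num_tokens, current_batch_size,
                        max_decode_steps):
    # 1. The batch would exceed the per-step token budget.
    if num_new_batched_tokens > max_num_batched_tokens:
        return True
    # 2. The average free KV space per sequence, after adding this request,
    #    would no longer cover max_decode_steps decode tokens.
    free_per_seq = (free_kv_space - num_tokens) / (current_batch_size + 1)
    return free_per_seq < max_decode_steps


# With 4096 free slots, a 1024-token request and 7 sequences already batched:
# (4096 - 1024) / 8 = 384 >= 256, so the batch can still grow.
assert not should_stop_growing(2048, 4096, 4096, 1024, 7, 256)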
9 changes: 5 additions & 4 deletions serve/mlc_serve/engine/sync_engine.py
@@ -18,7 +18,7 @@
     RequestState,
     SequenceOutput,
     check_stopping_sequences,
-    ValidationError
+    ValidationError,
 )
 from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
 from ..model.base import ModelArtifactConfig
@@ -96,8 +96,8 @@ def add(self, requests: list[Request]):
             if (
                 state.validation_err is not None
                 or state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
-                # Need to exclude requests which cannot fit into the kv_cache and can be processed
-                # at least max_decode_steps steps
+                # We make sure that the KV cache will have enough free space for this request to proceed
+                # decoding for at least self.max_decode_steps steps.
                 or self.cache_manager.get_kv_cache_size() - state.prompt_len < self.max_decode_steps
             ):
                 self.cancel(req.request_id)
@@ -296,7 +296,8 @@ def _adjust_batch(self):
                     num_new_batched_tokens,
                 )
                 break
-            # Make sure to leave some free space in the KV cache after a request is added or batched
+            # We make sure that the KV cache will have enough free space for this request to proceed
+            # decoding for at least self.max_decode_steps steps.
             if (
                 (self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
                 < self.max_decode_steps
