Commit

minor
masahi committed Dec 8, 2023
1 parent 18b8e41 commit bdb0be3
Showing 3 changed files with 15 additions and 15 deletions.
8 changes: 4 additions & 4 deletions serve/mlc_serve/engine/engine_common.py
@@ -19,7 +19,7 @@
     check_stopping_sequences,
 )
 from .metrics import PrometheusMetrics
-from .metrics_labels import *
+from .metrics_labels import NUM_CACHE_EVICTONS
 from .model_module import (
     DecodeRequest,
     PrefillRequest,
@@ -242,7 +242,7 @@ def evict_request(self):
         request_to_remove = min(
             self.current_batch.values(), key=lambda s: s.num_total_tokens
         )
-        # TODO parallel sampling: Properly support Evicting a multi-sequence request
+        # TODO parallel sampling: Properly support evicting a multi-sequence request
         assert (
             self.current_batch[request_to_remove.request_id].num_sequences == 1
         ), "Evicting a multi-sequence request is not supported."
@@ -262,7 +262,7 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
                 "Stop growing the batch due to min_decode_steps. Decode steps: %s",
                 max_new_tokens,
             )
-            # stop adding request if there isn't enough space to do a certain steps of decoding.
+            # stop adding request if there isn't enough space to do self.min_decode_steps steps of decoding.
             return None
 
         state = self.queue[0]
@@ -295,7 +295,7 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
             num_tokens = state.prompt_len
         num_new_batched_tokens += num_tokens
 
-        if num_new_batched_tokens > self.max_num_batched_tokens > 0:
+        if num_new_batched_tokens > self.max_num_batched_tokens:
             LOG.debug(
                 "Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
                 num_new_batched_tokens,
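The two try_grow_batch hunks above both guard batch growth against a token budget. As a rough illustration of the updated max_num_batched_tokens check, here is a minimal, self-contained sketch; BatchBudget and would_exceed are invented names, not part of mlc_serve, and the limit is assumed to be a positive integer.

# Illustrative sketch only, not mlc_serve code.
from dataclasses import dataclass


@dataclass
class BatchBudget:
    max_num_batched_tokens: int  # assumed to be a positive limit

    def would_exceed(self, num_new_batched_tokens: int) -> bool:
        # Mirrors the updated condition above: once the running token count
        # passes the configured limit, stop adding requests to the batch.
        return num_new_batched_tokens > self.max_num_batched_tokens


budget = BatchBudget(max_num_batched_tokens=4096)
assert not budget.would_exceed(4000)  # still fits in the batch
assert budget.would_exceed(5000)      # would overflow, so stop growing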
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -313,8 +313,8 @@ def run_generation_loop_worker(
     result_queue: multiprocessing.Queue,
     ready_event: multiprocessing.synchronize.Event,
     contextvars: Optional[Dict[str, Any]] = None,
-    enable_json_logs=False,
-    log_level="INFO",
+    enable_json_logs: bool = False,
+    log_level: str = "INFO",
 ):
     configure_logging(enable_json_logs, log_level)
     structlog.contextvars.bind_contextvars(**contextvars)
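The staging_engine_worker hunk only adds type annotations to two keyword arguments. As a self-contained sketch of how such a worker might be spawned and handed enable_json_logs / log_level, the demo_worker below is a stand-in, not run_generation_loop_worker itself, which takes further arguments not shown in the hunk.

# Illustrative sketch only: demo_worker stands in for the real worker.
import multiprocessing as mp
from typing import Any, Dict, Optional


def demo_worker(
    result_queue: mp.Queue,
    ready_event,  # multiprocessing.synchronize.Event
    contextvars: Optional[Dict[str, Any]] = None,
    enable_json_logs: bool = False,
    log_level: str = "INFO",
) -> None:
    # The real worker would configure logging with these options and bind
    # contextvars before entering its loop; here we just report back.
    ready_event.set()
    result_queue.put({"json_logs": enable_json_logs, "log_level": log_level})


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    results = ctx.Queue()
    ready = ctx.Event()
    proc = ctx.Process(
        target=demo_worker,
        kwargs={
            "result_queue": results,
            "ready_event": ready,
            "enable_json_logs": True,
            "log_level": "DEBUG",
        },
    )
    proc.start()
    ready.wait(timeout=10)
    print(results.get(timeout=10))
    proc.join()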
18 changes: 9 additions & 9 deletions serve/mlc_serve/engine/sync_engine.py
@@ -78,22 +78,19 @@ def add(self, requests: list[Request]):
             self.queue.extend(new_request_states)
             self.has_new_requests.notify_all()
 
-    def has_pending_requests(self) -> bool:
-        return bool(self.queue or self.current_batch)
-
-    def wait_for_request(self, timeout_seconds=None) -> bool:
-        with self.queue_lock:
-            return self.has_new_requests.wait_for(
-                self.has_pending_requests, timeout=timeout_seconds
-            )
-
     def cancel(self, request_id: RequestId):
         with self.queue_lock:
             # TODO: consider iterating throught the queue to find if request id exist
             # Otherwise cancel a request that's already finished will leave request_id
             # in the `requests_to_be_cancelled` set forever.
             self.requests_to_be_cancelled.add(request_id)
 
+    def wait_for_request(self, timeout_seconds=None) -> bool:
+        with self.queue_lock:
+            return self.has_new_requests.wait_for(
+                self.has_pending_requests, timeout=timeout_seconds
+            )
+
     def step(self) -> InferenceStepResult:
         logger.debug("Starting new inference step.")
 
@@ -253,6 +250,9 @@ def _adjust_batch(self):
                 num_new_batched_tokens = self.try_grow_batch(num_new_batched_tokens)
             self._discard_cancelled_requests_from_queue()
 
+    def has_pending_requests(self) -> bool:
+        return bool(self.queue or self.current_batch)
+
     def _discard_cancelled_requests_from_queue(self):
         """
         Requires the self.queue_lock to be held before calling this function.
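For context on the methods being shuffled in sync_engine.py: wait_for_request blocks on a threading.Condition until has_pending_requests returns true, passing the bound method itself as the predicate. A reduced, self-contained sketch of that pattern follows; TinyQueue is an invented name, not the engine's class, and it only models the queue side.

# Illustrative sketch only of the Condition.wait_for pattern shown above.
import threading
from collections import deque


class TinyQueue:
    def __init__(self) -> None:
        self.queue: deque = deque()
        self.queue_lock = threading.Lock()
        self.has_new_requests = threading.Condition(self.queue_lock)

    def has_pending_requests(self) -> bool:
        return bool(self.queue)

    def add(self, item) -> None:
        with self.queue_lock:
            self.queue.append(item)
            self.has_new_requests.notify_all()

    def wait_for_request(self, timeout_seconds=None) -> bool:
        # wait_for re-checks the predicate each time the condition is
        # notified and returns False if the timeout expires first.
        with self.queue_lock:
            return self.has_new_requests.wait_for(
                self.has_pending_requests, timeout=timeout_seconds
            )


q = TinyQueue()
threading.Timer(0.1, q.add, args=("request-1",)).start()
assert q.wait_for_request(timeout_seconds=1.0)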
