
Commit

Various updates
njhill committed Nov 6, 2024
1 parent 11498be commit e5414b4
Showing 9 changed files with 136 additions and 114 deletions.
44 changes: 44 additions & 0 deletions vllm/config.py
@@ -2038,3 +2038,47 @@ def __post_init__(self):
                self.model_config is not None and self.load_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
                self.model_config, self.load_config)

    def __str__(self):
        return ("model=%r, speculative_config=%r, tokenizer=%r, "
                "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
                "override_neuron_config=%s, "
                "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
                "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
                "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
                "pipeline_parallel_size=%d, "
                "disable_custom_all_reduce=%s, quantization=%s, "
                "enforce_eager=%s, kv_cache_dtype=%s, "
                "quantization_param_path=%s, device_config=%s, "
                "decoding_config=%r, observability_config=%r, "
                "seed=%d, served_model_name=%s, "
                "num_scheduler_steps=%d, enable_prefix_caching=%s, "
                "use_async_output_proc=%s, mm_processor_kwargs=%s") % \
            (self.model_config.model, self.speculative_config,
             self.model_config.tokenizer,
             self.model_config.skip_tokenizer_init,
             self.model_config.tokenizer_mode,
             self.model_config.revision,
             self.model_config.override_neuron_config,
             self.model_config.rope_scaling,
             self.model_config.rope_theta,
             self.model_config.tokenizer_revision,
             self.model_config.trust_remote_code,
             self.model_config.dtype,
             self.model_config.max_model_len,
             self.load_config.download_dir,
             self.load_config.load_format,
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size,
             self.parallel_config.disable_custom_all_reduce,
             self.model_config.quantization,
             self.model_config.enforce_eager,
             self.cache_config.cache_dtype,
             self.model_config.quantization_param_path,
             self.device_config.device, self.decoding_config,
             self.observability_config, self.model_config.seed,
             self.model_config.served_model_name,
             self.scheduler_config.num_scheduler_steps,
             self.cache_config.enable_prefix_caching,
             self.model_config.use_async_output_proc,
             self.model_config.mm_processor_kwargs)
6 changes: 5 additions & 1 deletion vllm/engine/output_processor/stop_checker.py
@@ -98,7 +98,11 @@ def check_stop_strings(
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
Returns tuple (stop_string, offset) if matched or else None.
Where stop_string is the matched stop string and offset is the
length to which output_text should be truncated, or -1 for no
truncation.
"""
if not new_char_count or not stop:
return None
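For reference, a minimal sketch of how a caller might apply the new (stop_string, offset) return contract described in the docstring above; the apply_stop helper and its names are illustrative assumptions, not part of the diff:

from typing import Optional, Tuple

def apply_stop(output_text: str,
               result: Optional[Tuple[str, int]]) -> Tuple[str, Optional[str]]:
    # result is what check_stop_strings() returned: None, or
    # (matched stop string, truncation offset), with -1 meaning no truncation.
    if result is None:
        return output_text, None
    stop_string, offset = result
    if offset != -1:
        output_text = output_text[:offset]
    return output_text, stop_string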
2 changes: 1 addition & 1 deletion vllm/outputs.py
@@ -123,7 +123,7 @@ def new(
token_ids: List[int],
finished: bool = False,
) -> "RequestOutput":
"""Initialize a new "empty" RequestOutput object."""
"""Initialize a new RequestOutput object."""

# TODO: Support `n` > 1.
completion_output = CompletionOutput(
12 changes: 9 additions & 3 deletions vllm/v1/engine/__init__.py
@@ -25,7 +25,7 @@ class DetokenizerRequest:
include_stop_str_in_output: bool


class EngineCoreRequest(msgspec.Struct):
class EngineCoreRequest(msgspec.Struct, omit_defaults=True):

# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
@@ -42,7 +42,10 @@ class EngineCoreRequest(msgspec.Struct):
lora_request: Optional[LoRARequest]


class EngineCoreOutput(msgspec.Struct, array_like=True):
class EngineCoreOutput(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):

request_id: str
new_token_ids: List[int]
@@ -51,7 +54,10 @@ class EngineCoreOutput(msgspec.Struct, array_like=True):
stop_reason: Union[int, str, None] = None


class EngineCoreOutputs(msgspec.Struct, array_like=True):
class EngineCoreOutputs(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):

#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
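The omit_defaults=True and gc=False options added here trim serialized size and per-object GC tracking for these high-volume messages. A self-contained sketch of the resulting encode/decode round trip with the same msgspec options (ExampleOutput is a simplified stand-in, not the real EngineCoreOutput):

from typing import List, Union

import msgspec

class ExampleOutput(msgspec.Struct,
                    array_like=True,
                    omit_defaults=True,
                    gc=False):
    request_id: str
    new_token_ids: List[int]
    finished: bool = False
    stop_reason: Union[int, str, None] = None

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(ExampleOutput)

# array_like=True serializes fields positionally; omit_defaults=True drops
# trailing fields still at their defaults, so this payload carries only the
# request id and the token ids.
payload = encoder.encode(ExampleOutput("req-0", [1, 2, 3]))
restored = decoder.decode(payload)
assert restored.new_token_ids == [1, 2, 3] and restored.finished is False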
99 changes: 50 additions & 49 deletions vllm/v1/engine/async_llm.py
@@ -1,7 +1,5 @@
import asyncio
from functools import partial
from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Type,
Union)
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -57,6 +55,8 @@ def __init__(

# Request streams (map of request_id -> AsyncStream).
self.request_streams: Dict[str, AsyncStream] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []

# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config,
@@ -75,28 +75,26 @@ def __init__(
asyncio_mode=True,
)

self.is_output_handler_running = False
self.output_handler = None

@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
vllm_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an AsyncLLMEngine from the EngineArgs."""
"""Creates an AsyncLLM from the EngineArgs."""

# Create the engine configs.
if engine_config is None:
if vllm_config is None:
vllm_config = engine_args.create_engine_config()
else:
vllm_config = engine_config

executor_class = cls._get_executor_cls(vllm_config)

# Create the AsyncLLMEngine.
# Create the AsyncLLM.
return cls(
vllm_config=vllm_config,
executor_class=executor_class,
Expand All @@ -110,8 +108,8 @@ def from_engine_args(
def shutdown(self):
self.engine_core.shutdown()

if hasattr(self, "output_handler"):
self.output_handler.cancel()
if handler := getattr(self, "output_handler", None):
handler.cancel()

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
@@ -151,7 +149,11 @@ async def add_request(
# 5) Return the generator.
return stream.generator()


# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
@@ -166,10 +168,9 @@ async def generate(
# We start the output_handler on the first call to generate() so that
# we can call __init__ before the event loop starts, which enables us
# to handle startup failure gracefully in the OpenAI server.
if not self.is_output_handler_running:
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
self.is_output_handler_running = True

async for output in await self.add_request(
request_id,
@@ -182,37 +183,31 @@ ):
):
yield output


async def _abort_requests(
self,
request_ids: Union[str, List[str]],
*,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
verbose: bool,
) -> None:
def _process_cancellations(self) -> List[str]:
"""
Abort requests. This function is called in two places:
* In output_handler_loop, if stop string is detected
* In iterate_with_cancellation (inside AsyncStream) when
a request disconnects. This function is used as a callback.
Process requests cancelled from user side since last iteration.
"""
client_aborted_reqs = self.client_aborted_requests
if not client_aborted_reqs:
return []

# Convert to a list if we got a single request_id str.
if isinstance(request_ids, str):
request_ids = [request_ids]
reqs_to_abort = client_aborted_reqs.copy()
client_aborted_reqs.clear()

# Abort in EngineCore and Detokenizer.
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)
self.detokenizer.abort_requests(reqs_to_abort)
for request_id in reqs_to_abort:
self._finish_stream(request_id)

# Remove from the request streams.
for request_id in request_ids:
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish(exception=exception)
return reqs_to_abort

if verbose:
logger.info("Aborted request %s.", request_id)
def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id)
if stream is not None:
stream.finish()

async def _abort_engine_requests(self, request_ids: List[str]):
if request_ids:
await self.engine_core.abort_requests_async(request_ids)

def _add_request_to_streams(
self,
@@ -224,8 +219,9 @@ def _add_request_to_streams(
if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")

abort_callback = partial(self._abort_requests, verbose=verbose)
stream = AsyncStream(request_id, abort_callback)
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream

if verbose:
@@ -241,12 +237,13 @@ def _process_request_outputs(self, request_outputs: List[RequestOutput]):
assert request_id in self.request_streams

# Each request in the API server pulls from these streams.
self.request_streams[request_id].put(request_output)
stream = self.request_streams.get(request_id)
if stream is not None:
stream.put(request_output)

# If finished, remove from the tracker.
if request_output.finished:
self.request_streams[request_id].finish()
self.request_streams.pop(request_id)
# If finished, remove from the tracker.
if request_output.finished:
self._finish_stream(request_id)

async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -262,8 +259,12 @@ async def _run_output_handler(self):
# 3) Put the RequestOutputs into the per-request AsyncStreams.
self._process_request_outputs(request_outputs)

# 4) Process any requests cancelled from the client side.
cancelled_reqs_to_abort = self._process_cancellations()
reqs_to_abort.extend(cancelled_reqs_to_abort)

# Abort any requests that finished due to stop strings.
await self._abort_requests(reqs_to_abort, verbose=False)
await self._abort_engine_requests(reqs_to_abort)

except BaseException as e:
logger.error(e)
@@ -310,7 +311,7 @@ async def do_log_stats(
logger.debug("Called do_log_stats.")

async def check_health(self) -> None:
logger.debug("Called do_log_stats.")
logger.debug("Called check_health.")

async def start_profile(self) -> None:
raise ValueError("Not supported on V1 yet.")
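To illustrate the cancellation bookkeeping introduced in this file: each AsyncStream is handed client_aborted_requests.append as its cancellation callback, so a stream only ever references the plain list rather than the AsyncLLM instance, and the output-handler loop later drains that list via _process_cancellations(). A standalone sketch of the same pattern; everything beyond the list/append idea is illustrative, not code from the commit:

from typing import Callable, Dict, List

class MiniStream:
    """Stands in for AsyncStream: stores only the cancel callback it is given."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]):
        self.request_id = request_id
        self._cancel = cancel

    def client_disconnected(self) -> None:
        # Record the cancellation; the engine loop aborts it later.
        self._cancel(self.request_id)

class MiniEngine:
    def __init__(self):
        self.client_aborted_requests: List[str] = []
        self.streams: Dict[str, MiniStream] = {}

    def add_request(self, request_id: str) -> MiniStream:
        # Pass the bound list method, not a method of self, so the stream
        # keeps no reference back to the engine object.
        stream = MiniStream(request_id, self.client_aborted_requests.append)
        self.streams[request_id] = stream
        return stream

    def process_cancellations(self) -> List[str]:
        reqs = self.client_aborted_requests.copy()
        self.client_aborted_requests.clear()
        for request_id in reqs:
            self.streams.pop(request_id, None)
        return reqs

engine = MiniEngine()
stream = engine.add_request("req-0")
stream.client_disconnected()
assert engine.process_cancellations() == ["req-0"]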
3 changes: 2 additions & 1 deletion vllm/v1/engine/async_stream.py
@@ -1,5 +1,6 @@
import asyncio
from typing import Any, Awaitable, AsyncGenerator, Callable, Optional, Type, Union
from typing import (Any, AsyncGenerator, Awaitable, Callable, Optional, Type,
Union)

from vllm.outputs import EmbeddingRequestOutput, RequestOutput

