docs: improved documentation of supported model config params through vllm
benlipkin committed Oct 23, 2024
1 parent b481be3 commit 5c170fd
Showing 1 changed file with 15 additions and 36 deletions.
51 changes: 15 additions & 36 deletions decoding/models.py
@@ -3,16 +3,18 @@
 This module wraps the vLLM library to provide a simple interface via the
 `LanguageModel` class. The class provides methods for conditionally generating
-strings with `generate` and scoring strings with `surprise`. An easy constructor
-is also provided to load a model by its Hugging Face model ID and manage memory, etc.
+strings with `LanguageModel.generate` and scoring strings with `LanguageModel.surprise`.
+An easy constructor is also provided to load a model by its Hugging Face model ID
+and specify optional parameters for memory management, KV caching, scheduling policy,
+quantization, LoRA, speculative decoding, and many other settings.
 """
 
 # ruff: noqa: E402
 
 import logging
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TypedDict, TypeGuard, Unpack
+from typing import TypeGuard, Unpack
 
 import jax.numpy as jnp
 
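For orientation, here is a minimal usage sketch of the interface the revised docstring describes. Only `from_id` and the keyword-only `surprise(contexts=..., queries=...)` signature are visible in this diff; the example strings and the assumption that the returned `FVX` is a JAX float vector (suggested by the `jax.numpy` import) are illustrative.

```python
# Minimal sketch, assuming the API surfaces shown elsewhere in this diff.
from decoding.models import LanguageModel

# Optional engine kwargs are forwarded to vLLM (per the new docstring).
llm = LanguageModel.from_id("gpt2", gpu_memory_utilization=0.5)

# Keyword-only signature taken from the hunk header below; the FVX
# return type is assumed to be a jax.numpy float vector.
scores = llm.surprise(
    contexts=["The capital of France is"],
    queries=[" Paris."],
)
```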
@@ -23,42 +25,13 @@
 _logger = logging.getLogger(__name__)
 _logger.info("Importing vLLM: This may take a moment...")
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, EngineArgs, SamplingParams
 from vllm.inputs import PromptType
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 
-class ModelParams(TypedDict, total=False):
-    """
-    Parameters for the LanguageModel constructor. These are optional and
-    correspond to the arguments for initializing
-    [`vllm.LLM`](https://docs.vllm.ai/en/latest/dev/offline_inference/llm.html).
-    See the vLLM documentation for more details.
-    """
-
-    tokenizer: str | None
-    tokenizer_mode: str
-    skip_tokenizer_init: bool
-    trust_remote_code: bool
-    tensor_parallel_size: int
-    dtype: str
-    quantization: str | None
-    revision: str | None
-    tokenizer_revision: str | None
-    seed: int
-    gpu_memory_utilization: float
-    swap_space: float
-    cpu_offload_gb: float
-    enforce_eager: bool | None
-    max_context_len_to_capture: int | None
-    max_seq_len_to_capture: int
-    disable_custom_all_reduce: bool
-    disable_async_output_proc: bool
-    enable_prefix_caching: bool
-
-
 @dataclass(frozen=True, kw_only=True)
 class LanguageModel:
     """
@@ -209,14 +182,18 @@ def surprise(self, *, contexts: Sequence[str], queries: Sequence[str]) -> FVX:
     def from_id(
         cls,
         model_id: str,
-        **model_kwargs: Unpack[ModelParams],
+        **model_kwargs: Unpack[EngineArgs],  # type: ignore[reportGeneralTypeIssues]
     ) -> "LanguageModel":
         """
         Load a language model by its Hugging Face model ID.
 
         Args:
             model_id: The Hugging Face model ID.
-            model_kwargs: Optional parameters for the model constructor.
+            **model_kwargs: Optional parameters for the model constructor. These are
+                passed to the [`vllm.LLM`](https://docs.vllm.ai/en/stable/dev/offline_inference/llm.html)
+                constructor, and through there to [`vllm.EngineArgs`](https://docs.vllm.ai/en/stable/models/engine_args.html).
+                Check the linked vLLM documentation for more details on what parameters
+                are available.
 
         Returns:
             A `LanguageModel` instance.
@@ -225,7 +202,9 @@ def from_id(
         ```python
         from decoding.models import LanguageModel
 
-        llm = LanguageModel.from_id("gpt2", gpu_memory_utilization=0.5)
+        llm = LanguageModel.from_id(
+            "gpt2", gpu_memory_utilization=0.5, enable_prefix_caching=True
+        )
         assert llm.tokenizer.name_or_path == "gpt2"
         ```
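The body of `from_id` is not shown in this diff; a plausible sketch of the forwarding it implies follows, hypothetical rather than the repository's exact code:

```python
from vllm import LLM

def _from_id_sketch(model_id: str, **model_kwargs):
    # vllm.LLM accepts engine options as keyword arguments and uses them
    # to build a vllm.EngineArgs, so settings such as enable_prefix_caching
    # reach the engine config without this wrapper enumerating them.
    return LLM(model=model_id, **model_kwargs)
```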
