diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a8b6103..19fcd76 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.8
+    rev: v0.7.0
     hooks:
       - id: ruff
         args: [ --fix ]
diff --git a/decoding/models.py b/decoding/models.py
index 9225673..b0b955d 100644
--- a/decoding/models.py
+++ b/decoding/models.py
@@ -24,6 +24,7 @@
 _logger.info("Importing vLLM: This may take a moment...")
 
 from vllm import LLM, SamplingParams
+from vllm.inputs import PromptType
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -55,6 +56,7 @@ class ModelParams(TypedDict, total=False):
     max_seq_len_to_capture: int
     disable_custom_all_reduce: bool
     disable_async_output_proc: bool
+    enable_prefix_caching: bool
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -269,11 +271,11 @@
         raise ValueError(msg)
 
 
-def _guard_prompt_text(x: str | None) -> str:
-    if x is None:
-        msg = "Prompt text should not be None"
-        raise ValueError(msg)
-    return x
+def _guard_prompt_text(x: PromptType | None) -> str:
+    if isinstance(x, str):  # `PromptType` is a union type of `str` and other types
+        return x
+    msg = "Prompt text should be a string"
+    raise ValueError(msg)
 
 
 def _guard_output_logp(x: float | None) -> float:
diff --git a/setup.py b/setup.py
index a41a847..181fdaf 100644
--- a/setup.py
+++ b/setup.py
@@ -9,16 +9,16 @@
 core_requirements = [
     "jax==0.4.31",
     "jaxtyping==0.2.34",
-    "vllm==0.6.2",
+    "vllm==0.6.3.post1",
 ]
 dev_requirements = [
-    "pdoc==14.7.0",
-    "pre-commit==3.8.0",
-    "pyright==1.1.382.post1",
+    "pdoc==15.0.0",
+    "pre-commit==4.0.1",
+    "pyright==1.1.385",
     "pytest==8.3.3",
     "pytest-cov==5.0.0",
     "pytest-html==4.1.1",
-    "ruff==0.6.8",
+    "ruff==0.7.0",
 ]
 
 with pathlib.Path("README.md").open(encoding="utf-8") as f:
@@ -26,7 +26,7 @@
 
 setuptools.setup(
     name="decoding",
-    version="0.1.1",
+    version="0.1.2",
     description="Composable inference algorithms with LLMs and programmable logic",
     long_description=readme,
     long_description_content_type="text/markdown",