From c1c8d39633808081b7aae982a3c20044c44ec701 Mon Sep 17 00:00:00 2001
From: br3no
Date: Thu, 8 Feb 2024 16:53:11 +0100
Subject: [PATCH 1/2] #288 guided decoding

Added support for guided decoding in `api_server` by integrating
_outlines_ (https://github.com/outlines-dev/outlines).
---
 requirements-guided-decoding.txt       |  1 +
 tests/samplers/test_sampler.py         |  3 +-
 vllm/entrypoints/api_server.py         | 37 +++++++++++++++++
 vllm/model_executor/guided_decoding.py | 56 ++++++++++++++++++++++++++
 vllm/model_executor/layers/sampler.py  |  3 +-
 vllm/sampling_params.py                |  7 ++--
 6 files changed, 102 insertions(+), 5 deletions(-)
 create mode 100644 requirements-guided-decoding.txt
 create mode 100644 vllm/model_executor/guided_decoding.py

diff --git a/requirements-guided-decoding.txt b/requirements-guided-decoding.txt
new file mode 100644
index 0000000000000..04ef481381d00
--- /dev/null
+++ b/requirements-guided-decoding.txt
@@ -0,0 +1 @@
+outlines == 0.0.27
\ No newline at end of file
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index d34f32d03fee0..1523ca2be7fb8 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -230,7 +230,8 @@ def test_sampler_logits_processors(seed: int, device: str):
     # This sample logits processor gives infinite score to the i-th token,
     # where i is the length of the input sequence.
     # We therefore expect the output token sequence to be [0, 1, 2, ...]
-    def pick_ith(token_ids, logits):
+    # Since this processor is stateless, the seq_id is not used
+    def pick_ith(token_ids, logits, seq_id):
         logits[len(token_ids)] = float("inf")
         return logits
 
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index f7b8d258fae4c..f20b7b1fd3392 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -10,6 +10,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.model_executor.guided_decoding import GuidedDecodingEngine, GuidedDecodingMode, get_logits_processor
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 app = FastAPI()
@@ -28,6 +29,8 @@ async def generate(request: Request) -> Response:
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
+    - schema: the JSON schema to use for the generation (if regex is not provided).
+    - regex: the regex to use for the generation (if schema is not provided).
    - stream: whether to stream the results or not.
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
@@ -35,6 +38,11 @@ async def generate(request: Request) -> Response:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     prefix_pos = request_dict.pop("prefix_pos", None)
     stream = request_dict.pop("stream", False)
+
+    if args.guided_decoding_engine is not None:
+        # Add logits processors if guided decoding is requested
+        _setup_guided_decoding(request_dict)
+
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
@@ -72,6 +80,28 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return JSONResponse(ret)
 
 
+def _setup_guided_decoding(request_dict):
+    json_schema = request_dict.pop("schema", None)
+    regex_string = request_dict.pop("regex", None)
+
+    if json_schema is not None or regex_string is not None:
+        assert json_schema is None or regex_string is None, \
+            "Only one of 'schema' and 'regex' can be provided."
+
+        guided_decoding_engine = GuidedDecodingEngine(
+            args.guided_decoding_engine)
+        mode = GuidedDecodingMode(
+            "schema" if json_schema is not None else "regex")
+        logits_processors = [
+            get_logits_processor(json_schema or regex_string, mode,
+                                 guided_decoding_engine, engine.engine)
+        ]
+        if request_dict.get("logits_processors") is None:
+            request_dict["logits_processors"] = logits_processors
+        else:
+            request_dict["logits_processors"].extend(logits_processors)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default=None)
@@ -83,6 +113,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         type=str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument(
+        "--guided-decoding-engine",
+        type=str,
+        default=None,
+        help=
+        "Which engine to use for guided decoding. Currently only `outlines` is supported."
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py
new file mode 100644
index 0000000000000..84845b0e241a9
--- /dev/null
+++ b/vllm/model_executor/guided_decoding.py
@@ -0,0 +1,56 @@
+from enum import Enum
+import time
+from typing import List, Union
+try:
+    from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor
+except ImportError:
+    raise ValueError("Please install 'outlines' (pip install outlines) to use guided generation.")
+import torch
+
+from vllm.engine.llm_engine import LLMEngine
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import LogitsProcessor
+
+class GuidedDecodingEngine(Enum):
+    OUTLINES = "outlines"
+
+class GuidedDecodingMode(Enum):
+    REGEX = "regex"
+    JSON_SCHEMA = "schema"
+
+class OutlinesJSONLogitsProcessor(JSONLogitsProcessor):
+
+    def __init__(self, json_schema: dict, llm: LLM):
+        super().__init__(json_schema, llm)
+
+    def __call__(
+        self,
+        input_ids: List[int],
+        scores: torch.Tensor,
+        seq_id: int,
+    ) -> torch.Tensor:
+        return super().__call__(seq_id, input_ids, scores)
+
+
+class OutlinesRegexLogitsProcessor(RegexLogitsProcessor):
+
+    def __init__(self, regex: str, llm: LLM):
+        super().__init__(regex, llm)
+
+    def __call__(
+        self,
+        input_ids: List[int],
+        scores: torch.Tensor,
+        seq_id: int,
+    ) -> torch.Tensor:
+        return super().__call__(seq_id, input_ids, scores)
+
+
+def get_logits_processor(specification: Union[str, dict], mode: GuidedDecodingMode, engine: GuidedDecodingEngine, llm_engine: LLMEngine):
+    if engine == GuidedDecodingEngine.OUTLINES:
+        if mode == GuidedDecodingMode.JSON_SCHEMA:
+            return OutlinesJSONLogitsProcessor(specification, llm_engine)
+        elif mode == GuidedDecodingMode.REGEX:
+            return OutlinesRegexLogitsProcessor(specification, llm_engine)
+    else:
+        raise ValueError(f"Unknown engine: {engine}")
\ No newline at end of file
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index bc86a916b5bbf..250a9f14d77a9 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -153,7 +153,8 @@ def _apply_logits_processors(
             logits_row = logits[logits_row_idx]
             token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
             for logits_processor in logits_processors:
-                logits_row = logits_processor(token_ids, logits_row)
+                logits_row = logits_processor(token_ids, logits_row,
+                                              seq_id)
                 logits[logits_row_idx] = logits_row
                 logits_row_idx += 1
         else:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index bb7d0002c910c..a3c577dbd96e4 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -14,10 +14,11 @@ class SamplingType(IntEnum):
     BEAM = 2
 
 
-LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
+LogitsProcessor = Callable[[List[int], torch.Tensor, int], torch.Tensor]
 """LogitsProcessor is a function that takes a list of previously generated
-tokens and a tensor of the logits for the next token, and returns a modified
-tensor of logits to sample from."""
+tokens, a tensor of the logits for the next token, and an integer sequence
+id, and returns a modified tensor of logits to sample from. The sequence id
+is used to distinguish different generations, in case the processor is stateful."""
 
 
 class SamplingParams:

From 8b0395a5089170064ab22e51d77fb5d13bad43d1 Mon Sep 17 00:00:00 2001
From: br3no
Date: Thu, 8 Feb 2024 17:20:13 +0100
Subject: [PATCH 2/2] Fixing ruff and yapf complaints

---
 vllm/model_executor/guided_decoding.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py
index 84845b0e241a9..ea0c1ebb59689 100644
--- a/vllm/model_executor/guided_decoding.py
+++ b/vllm/model_executor/guided_decoding.py
@@ -1,23 +1,26 @@
 from enum import Enum
-import time
 from typing import List, Union
 try:
     from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor
-except ImportError:
-    raise ValueError("Please install 'outlines' (pip install outlines) to use guided generation.")
+except ImportError as e:
+    raise ValueError(
+        "Please install 'outlines' (pip install outlines) to use guided generation."
+    ) from e
 import torch
 
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
-from vllm.sampling_params import LogitsProcessor
+
 
 class GuidedDecodingEngine(Enum):
     OUTLINES = "outlines"
 
+
 class GuidedDecodingMode(Enum):
     REGEX = "regex"
     JSON_SCHEMA = "schema"
 
+
 class OutlinesJSONLogitsProcessor(JSONLogitsProcessor):
 
     def __init__(self, json_schema: dict, llm: LLM):
@@ -45,12 +48,14 @@ def __call__(
     ) -> torch.Tensor:
         return super().__call__(seq_id, input_ids, scores)
 
-
-def get_logits_processor(specification: Union[str, dict], mode: GuidedDecodingMode, engine: GuidedDecodingEngine, llm_engine: LLMEngine):
+
+def get_logits_processor(specification: Union[str,
+                                              dict], mode: GuidedDecodingMode,
+                         engine: GuidedDecodingEngine, llm_engine: LLMEngine):
     if engine == GuidedDecodingEngine.OUTLINES:
         if mode == GuidedDecodingMode.JSON_SCHEMA:
             return OutlinesJSONLogitsProcessor(specification, llm_engine)
         elif mode == GuidedDecodingMode.REGEX:
             return OutlinesRegexLogitsProcessor(specification, llm_engine)
-    else:
-        raise ValueError(f"Unknown engine: {engine}")
\ No newline at end of file
+    else:
+        raise ValueError(f"Unknown engine: {engine}")
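
A usage sketch for reviewers: with the server started via
`python -m vllm.entrypoints.api_server --guided-decoding-engine outlines`,
a client constrains generation by sending either a `schema` or a `regex`
field alongside the usual sampling parameters, as documented in the patched
`generate` docstring. The `/generate` route and port 8000 below are the
usual `api_server` defaults and are assumed here (neither appears in the
diff), and the schema itself is a made-up example:

    import requests

    # A hypothetical JSON schema; the patch hands it straight through to
    # outlines' JSONLogitsProcessor via get_logits_processor().
    person_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }

    # "prompt", "schema"/"regex" and "stream" are the request fields from the
    # patched docstring; remaining fields are passed to SamplingParams.
    response = requests.post(
        "http://localhost:8000/generate",
        json={
            "prompt": "Describe a fictional person as JSON.",
            "schema": person_schema,  # mutually exclusive with "regex"
            "max_tokens": 128,
            "stream": False,
        },
    )
    print(response.json())

Sending both `schema` and `regex` in one request trips the assertion in
`_setup_guided_decoding`, so clients should pass at most one of them.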