Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for guided decoding (fixes #288) #2815

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-guided-decoding.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
outlines == 0.0.27
3 changes: 2 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def test_sampler_logits_processors(seed: int, device: str):
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
# The processor keeps no per-sequence state, so seq_id goes unused.
def pick_ith(token_ids, logits, seq_id):
    forced_index = len(token_ids)
    logits[forced_index] = float("inf")
    return logits

Expand Down
37 changes: 37 additions & 0 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.model_executor.guided_decoding import GuidedDecodingEngine, GuidedDecodingMode, get_logits_processor

TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
Expand All @@ -28,13 +29,20 @@ async def generate(request: Request) -> Response:

The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- schema: the JSON schema to use for the generation (if regex is not provided).
- regex: the regex to use for the generation (if schema is not provided).
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)

if args.guided_decoding_engine is not None:
# Add logits processors if guided decoding is requested
_setup_guided_decoding(request_dict)

sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

Expand Down Expand Up @@ -72,6 +80,28 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return JSONResponse(ret)


def _setup_guided_decoding(request_dict):
json_schema = request_dict.pop("schema", None)
regex_string = request_dict.pop("regex", None)

if json_schema is not None or regex_string is not None:
assert json_schema is None or regex_string is None, \
"Only one of 'schema' and 'regex' can be provided."

guided_decoding_engine = GuidedDecodingEngine(
args.guided_decoding_engine)
mode = GuidedDecodingMode(
"schema" if json_schema is not None else "regex")
logits_processors = [
get_logits_processor(json_schema or regex_string, mode,
guided_decoding_engine, engine.engine)
]
if request_dict.get("logits_processors") is None:
request_dict["logits_processors"] = logits_processors
else:
request_dict["logits_processors"].extend(logits_processors)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
Expand All @@ -83,6 +113,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
# Opt-in flag: guided decoding is only set up for a request when this is set.
parser.add_argument(
    "--guided-decoding-engine",
    type=str,
    default=None,
    help="What engine for guided decoding to use. "
    "Currently only `outlines` is supported.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

Expand Down
61 changes: 61 additions & 0 deletions vllm/model_executor/guided_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from enum import Enum
from typing import List, Union
try:
from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor
except ImportError as e:
raise ValueError(
"Please install 'outlines' (pip install outlines) to use guided generation."
) from e
import torch

from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM


class GuidedDecodingEngine(Enum):
    """Backends that can build logits processors for guided decoding."""

    # Value is the string accepted on the command line.
    OUTLINES = "outlines"


class GuidedDecodingMode(Enum):
    """Kinds of specification the generated output can be constrained by."""

    # Output must match a regular expression.
    REGEX = "regex"
    # Output must conform to a JSON schema.
    JSON_SCHEMA = "schema"


class OutlinesJSONLogitsProcessor(JSONLogitsProcessor):
    """JSON-schema logits processor callable as
    ``(input_ids, scores, seq_id)``.

    The parent's ``__call__`` is invoked with ``seq_id`` first; this
    subclass only reorders the arguments.
    """

    def __init__(self, json_schema: dict, llm: LLM):
        # NOTE(review): `get_logits_processor` passes an LLMEngine here,
        # not an LLM -- confirm which one the parent class expects.
        super().__init__(json_schema, llm)

    def __call__(
        self,
        input_ids: List[int],
        scores: torch.Tensor,
        seq_id: int,
    ) -> torch.Tensor:
        """Return the logits produced by the parent processor."""
        constrained = super().__call__(seq_id, input_ids, scores)
        return constrained


class OutlinesRegexLogitsProcessor(RegexLogitsProcessor):
    """Regex logits processor callable as ``(input_ids, scores, seq_id)``.

    The parent's ``__call__`` is invoked with ``seq_id`` first; this
    subclass only reorders the arguments.
    """

    def __init__(self, regex: str, llm: LLM):
        # NOTE(review): `get_logits_processor` passes an LLMEngine here,
        # not an LLM -- confirm which one the parent class expects.
        super().__init__(regex, llm)

    def __call__(
        self,
        input_ids: List[int],
        scores: torch.Tensor,
        seq_id: int,
    ) -> torch.Tensor:
        """Return the logits produced by the parent processor."""
        return super().__call__(seq_id, input_ids, scores)


# Backward-compatible alias for the original, misspelled class name.
OulinesRegexLogitsProcessor = OutlinesRegexLogitsProcessor


def get_logits_processor(specification: Union[str, dict],
                         mode: GuidedDecodingMode,
                         engine: GuidedDecodingEngine,
                         llm_engine: LLMEngine):
    """Build the logits processor for a guided-decoding request.

    Args:
        specification: The JSON schema or regex string to constrain the
            output with, interpreted according to ``mode``.
        mode: Whether ``specification`` is a JSON schema or a regex.
        engine: The guided decoding backend to use.
        llm_engine: Passed through to the processor's constructor.

    Returns:
        A logits processor instance for the given mode.

    Raises:
        ValueError: If ``engine`` or ``mode`` is not supported.
    """
    # Fail loudly on an unsupported engine instead of falling through and
    # returning None, which would end up in the logits processor list.
    if engine != GuidedDecodingEngine.OUTLINES:
        raise ValueError(f"Unknown guided decoding engine: {engine}")

    if mode == GuidedDecodingMode.JSON_SCHEMA:
        return OutlinesJSONLogitsProcessor(specification, llm_engine)
    if mode == GuidedDecodingMode.REGEX:
        return OulinesRegexLogitsProcessor(specification, llm_engine)
    raise ValueError(f"Unknown mode: {mode}")
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def _apply_logits_processors(
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits_row = logits_processor(token_ids, logits_row,
seq_id)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
Expand Down
7 changes: 4 additions & 3 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ class SamplingType(IntEnum):
BEAM = 2


LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
LogitsProcessor = Callable[[List[int], torch.Tensor, int], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""
tokens, a tensor of the logits for the next token and an integer sequence id,
and returns a modified tensor of logits to sample from. The sequence id is used
to distinguish different generations, in case the processor is stateful."""


class SamplingParams:
Expand Down
Loading