Merge branch 'SUPPORT_GUIDED_DECODING-vllm-project#288' into main
br3no committed Feb 8, 2024
2 parents 931746b + c1c8d39 commit 04f1e19
Showing 6 changed files with 102 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements-guided-decoding.txt
@@ -0,0 +1 @@
outlines == 0.0.27
3 changes: 2 additions & 1 deletion tests/samplers/test_sampler.py
@@ -230,7 +230,8 @@ def test_sampler_logits_processors(seed: int, device: str):
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
-    def pick_ith(token_ids, logits):
+    # Since this processor is stateless, the seq_id is not used
+    def pick_ith(token_ids, logits, seq_id):
logits[len(token_ids)] = float("inf")
return logits

37 changes: 37 additions & 0 deletions vllm/entrypoints/api_server.py
@@ -10,6 +10,7 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.model_executor.guided_decoding import GuidedDecodingEngine, GuidedDecodingMode, get_logits_processor

TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
@@ -28,13 +29,20 @@ async def generate(request: Request) -> Response:
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- schema: the JSON schema to use for the generation (if regex is not provided).
- regex: the regex to use for the generation (if schema is not provided).
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)

if args.guided_decoding_engine is not None:
# Add logits processors if guided decoding is requested
_setup_guided_decoding(request_dict)

sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

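As a usage sketch, a request exercising the new fields might look like this (it assumes the server is running locally on the default port 8000 and was started with `--guided-decoding-engine outlines`; the schema is illustrative, and whether it is sent as a JSON object or a string may depend on the outlines version):

    import requests

    # Illustrative schema; any JSON schema accepted by outlines should work here.
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }

    response = requests.post(
        "http://localhost:8000/generate",  # assumed host/port
        json={
            "prompt": "Describe a person as JSON: ",
            "schema": schema,  # or pass "regex" instead, but not both
            "max_tokens": 64,
        },
    )
    print(response.json())
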
@@ -72,6 +80,28 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return JSONResponse(ret)


def _setup_guided_decoding(request_dict):
json_schema = request_dict.pop("schema", None)
regex_string = request_dict.pop("regex", None)

if json_schema is not None or regex_string is not None:
assert json_schema is None or regex_string is None, \
"Only one of 'schema' and 'regex' can be provided."

guided_decoding_engine = GuidedDecodingEngine(
args.guided_decoding_engine)
mode = GuidedDecodingMode(
"schema" if json_schema is not None else "regex")
logits_processors = [
get_logits_processor(json_schema or regex_string, mode,
guided_decoding_engine, engine.engine)
]
if request_dict.get("logits_processors") is None:
request_dict["logits_processors"] = logits_processors
else:
request_dict["logits_processors"].extend(logits_processors)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
@@ -83,6 +113,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--guided-decoding-engine",
type=str,
default=None,
help=
"What engine for guided decoding to use. Currently only `oulines` is supported."
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

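With the new flag, the server can then be started along these lines (the model name is illustrative; the pinned outlines dependency from requirements-guided-decoding.txt must be installed first):

    python -m vllm.entrypoints.api_server --model facebook/opt-125m --guided-decoding-engine outlines
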
56 changes: 56 additions & 0 deletions vllm/model_executor/guided_decoding.py
@@ -0,0 +1,56 @@
from enum import Enum
from typing import List, Union
try:
from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor
except ImportError:
raise ValueError("Please install 'outlines' (pip install outlines) to use guided generation.")
import torch

from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import LogitsProcessor

class GuidedDecodingEngine(Enum):
OUTLINES = "outlines"

class GuidedDecodingMode(Enum):
REGEX = "regex"
JSON_SCHEMA = "schema"

class OutlinesJSONLogitsProcessor(JSONLogitsProcessor):

def __init__(self, json_schema: dict, llm: LLM):
super().__init__(json_schema, llm)

def __call__(
self,
input_ids: List[int],
scores: torch.Tensor,
seq_id: int,
) -> torch.Tensor:
return super().__call__(seq_id, input_ids, scores)


class OutlinesRegexLogitsProcessor(RegexLogitsProcessor):

def __init__(self, regex: str, llm: LLM):
super().__init__(regex, llm)

def __call__(
self,
input_ids: List[int],
scores: torch.Tensor,
seq_id: int,
) -> torch.Tensor:
return super().__call__(seq_id, input_ids, scores)


def get_logits_processor(specification: Union[str, dict], mode: GuidedDecodingMode, engine: GuidedDecodingEngine, llm_engine: LLMEngine):
if engine == GuidedDecodingEngine.OUTLINES:
if mode == GuidedDecodingMode.JSON_SCHEMA:
return OutlinesJSONLogitsProcessor(specification, llm_engine)
elif mode == GuidedDecodingMode.REGEX:
            return OutlinesRegexLogitsProcessor(specification, llm_engine)
else:
raise ValueError(f"Unknown mode: {mode}")
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/sampler.py
@@ -153,7 +153,8 @@ def _apply_logits_processors(
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
-                logits_row = logits_processor(token_ids, logits_row)
+                logits_row = logits_processor(token_ids, logits_row,
+                                              seq_id)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
7 changes: 4 additions & 3 deletions vllm/sampling_params.py
@@ -14,10 +14,11 @@ class SamplingType(IntEnum):
BEAM = 2


-LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
+LogitsProcessor = Callable[[List[int], torch.Tensor, int], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
-tokens and a tensor of the logits for the next token, and returns a modified
-tensor of logits to sample from."""
+tokens, a tensor of the logits for the next token and an integer sequence id,
+and returns a modified tensor of logits to sample from. The sequence id is used
+to distinguish different generations, in case the processor is stateful."""


class SamplingParams:
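The extra `seq_id` argument is what makes stateful processors practical; a minimal sketch of one that keeps per-sequence state (the class name and penalty value are illustrative):

    from collections import defaultdict
    from typing import List

    import torch


    class PerSequenceRepetitionPenalty:
        """Discourages tokens that a given sequence has already generated."""

        def __init__(self, penalty: float = -1.0e4):
            self.penalty = penalty
            self.seen = defaultdict(set)  # seq_id -> previously generated token ids

        def __call__(self, token_ids: List[int], logits: torch.Tensor,
                     seq_id: int) -> torch.Tensor:
            # Penalize only tokens this particular sequence has produced before.
            for token in self.seen[seq_id]:
                logits[token] = self.penalty
            self.seen[seq_id].update(token_ids)
            return logits
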
