
[Infer] Add Simple Protocol for simple request and response #244

Merged · 12 commits · Jun 21, 2024
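Summary: the simple serving API gets a pydantic request/response pair, SimpleRequest for validated inputs and SimpleModelResponse as a thin wrapper around requests.Response, with the example clients, the deployment handler, and a new test updated to match. A minimal sketch of the resulting client flow (the http://localhost:8000/gpt2 endpoint is borrowed from the test below, for illustration only):

import requests

from llm_on_ray.inference.api_simple_backend.simple_protocol import (
    SimpleRequest,
    SimpleModelResponse,
)

# Validation happens before any HTTP call: an empty prompt or an unknown
# config key raises a pydantic ValidationError when the request is built.
req = SimpleRequest(text="Once upon a time", config={"max_new_tokens": 64})

outputs = requests.post("http://localhost:8000/gpt2", json=req.dict())
outputs.raise_for_status()

print(SimpleModelResponse.from_requests_response(outputs).text)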
20 changes: 11 additions & 9 deletions examples/inference/api_server_simple/query_dynamic_batch.py
@@ -18,6 +18,10 @@
import aiohttp
import argparse
from typing import Dict, Union
from llm_on_ray.inference.api_simple_backend.simple_protocol import (
SimpleRequest,
SimpleModelResponse,
)

parser = argparse.ArgumentParser(
description="Example script to query with multiple requests", add_help=True
@@ -63,9 +67,8 @@
config["top_k"] = float(args.top_k)


-async def send_query(session, endpoint, prompt, config):
-    json_request = {"text": prompt, "config": config}
-    async with session.post(endpoint, json=json_request) as resp:
+async def send_query(session, endpoint, req):
+    async with session.post(endpoint, json=req.dict()) as resp:
return await resp.text()


@@ -86,16 +89,15 @@ async def send_query(session, endpoint, prompt, config):

configs = [config1] * 5 + [config2] * (len(prompts) - 5)

reqs = [SimpleRequest(text=prompt, config=config) for prompt, config in zip(prompts, configs)]


-async def send_all_query(endpoint, prompts, configs):
+async def send_all_query(endpoint, reqs):
async with aiohttp.ClientSession() as session:
-        tasks = [
-            send_query(session, endpoint, prompt, config)
-            for prompt, config in zip(prompts, configs)
-        ]
+        tasks = [send_query(session, endpoint, req) for req in reqs]
responses = await asyncio.gather(*tasks)
print("\n--------------\n".join(responses))
print("\nTotal responses:", len(responses))


-asyncio.run(send_all_query(args.model_endpoint, prompts, configs))
+asyncio.run(send_all_query(args.model_endpoint, reqs))
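As a side note on what actually goes over the wire, req.dict() turns the model back into a plain dict for aiohttp; a quick illustrative check of the payload shape (values are examples only):

from llm_on_ray.inference.api_simple_backend.simple_protocol import SimpleRequest

req = SimpleRequest(text="Once upon a time", config={"max_new_tokens": 128})
print(req.dict())
# {'text': 'Once upon a time', 'config': {'max_new_tokens': 128}, 'stream': False}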
14 changes: 10 additions & 4 deletions examples/inference/api_server_simple/query_single.py
@@ -17,6 +17,10 @@
import requests
import argparse
from typing import Dict, Union
from llm_on_ray.inference.api_simple_backend.simple_protocol import (
SimpleRequest,
SimpleModelResponse,
)

parser = argparse.ArgumentParser(
description="Example script to query with single request", add_help=True
@@ -66,20 +70,22 @@
if args.top_k:
config["top_k"] = float(args.top_k)

-sample_input = {"text": prompt, "config": config, "stream": args.streaming_response}
+sample_input = SimpleRequest(text=prompt, config=config, stream=args.streaming_response)

proxies = {"http": None, "https": None}
outputs = requests.post(
args.model_endpoint,
proxies=proxies, # type: ignore
-    json=sample_input,
+    json=sample_input.dict(),
stream=args.streaming_response,
)

outputs.raise_for_status()

simple_response = SimpleModelResponse.from_requests_response(outputs)
if args.streaming_response:
-    for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
+    for output in simple_response.iter_content(chunk_size=1, decode_unicode=True):
print(output, end="", flush=True)
print()
else:
-    print(outputs.text, flush=True)
+    print(simple_response.text, flush=True)
91 changes: 91 additions & 0 deletions llm_on_ray/inference/api_simple_backend/simple_protocol.py
@@ -0,0 +1,91 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Union, Iterator, List
import requests
from pydantic import BaseModel, ValidationError, validator


class SimpleRequest(BaseModel):
text: str
config: Dict[str, Union[int, float]] = {}
stream: Optional[bool] = False

@validator("text")
def text_must_not_be_empty(cls, v):
if not v.strip():
raise ValueError("Empty prompt is not supported.")
return v

@validator("config", pre=True)
def check_config_type(cls, value):
allowed_keys = ["max_new_tokens", "temperature", "top_p", "top_k"]
allowed_set = set(allowed_keys)
config_dict = value.keys()
config_keys = [key for key in config_dict]
config_set = set(config_keys)

if not isinstance(value, dict):
raise ValueError("Config must be a dictionary")

if not all(isinstance(key, str) for key in value.keys()):
raise ValueError("All keys in config must be strings")

if not all(isinstance(val, (int, float)) for val in value.values()):
raise ValueError("All values in config must be integers or floats")

if not config_set.issubset(allowed_set):
invalid_keys = config_set - allowed_set
raise ValueError(f'Invalid config keys: {", ".join(invalid_keys)}')

return value

@validator("stream", pre=True)
def check_stream_type(cls, value):
if not isinstance(value, bool) and value is not None:
raise ValueError("Stream must be a boolean or None")
return value


class SimpleModelResponse(BaseModel):
headers: Dict[str, str]
text: str
content: bytes
status_code: int
url: str

class Config:
arbitrary_types_allowed = True

response: Optional[requests.Response] = None

@staticmethod
def from_requests_response(response: requests.Response):
return SimpleModelResponse(
headers=dict(response.headers),
text=response.text,
content=response.content,
status_code=response.status_code,
url=response.url,
response=response,
)

def iter_content(
self, chunk_size: Optional[int] = 1, decode_unicode: bool = False
) -> Iterator[Union[bytes, str]]:
if self.response is not None:
return self.response.iter_content(chunk_size=chunk_size, decode_unicode=decode_unicode)
else:
return iter([])
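As an illustrative aside, the validators above reject malformed input on the client side with a pydantic ValidationError, for example:

from pydantic import ValidationError

from llm_on_ray.inference.api_simple_backend.simple_protocol import SimpleRequest

# "beam_width" is a made-up key used only to trigger the invalid-key branch.
for bad in (
    {"text": ""},                                       # empty prompt
    {"text": "hi", "config": {"beam_width": 4}},        # key not in the allowed set
    {"text": "hi", "config": {"max_new_tokens": "x"}},  # non-numeric value
):
    try:
        SimpleRequest(**bad)
    except ValidationError as err:
        print(err.errors()[0]["msg"])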
20 changes: 9 additions & 11 deletions llm_on_ray/inference/predictor_deployment.py
@@ -34,6 +34,10 @@
ErrorResponse,
ModelResponse,
)
from llm_on_ray.inference.api_simple_backend.simple_protocol import (
SimpleRequest,
SimpleModelResponse,
)
from llm_on_ray.inference.predictor import GenerateInput
from llm_on_ray.inference.utils import get_prompt_format, PromptFormat
from llm_on_ray.inference.api_openai_backend.tools import OpenAIToolsPrompter, ChatPromptCapture
@@ -379,24 +383,18 @@ def preprocess_prompts(

async def __call__(self, http_request: Request) -> Union[StreamingResponse, JSONResponse, str]:
self.use_openai = False

try:
-            json_request: Dict[str, Any] = await http_request.json()
+            request: Dict[str, Any] = await http_request.json()
except ValueError:
return JSONResponse(
status_code=400,
content="Invalid JSON format from http request.",
)
-        streaming_response = json_request["stream"] if "stream" in json_request else False
-        input = json_request["text"] if "text" in json_request else ""
-
-        if input == "":
-            return JSONResponse(
-                status_code=400,
-                content="Empty prompt is not supported.",
-            )
-        config = json_request["config"] if "config" in json_request else {}
        # return prompt or list of prompts preprocessed
+        streaming_response = request["stream"]
+        input = request["text"]
+        config = request["config"]

prompts = self.preprocess_prompts(input)

# Handle streaming response
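With the simplified handler, the deployment indexes request["stream"], request["text"], and request["config"] directly, so the request body is expected to carry all three keys; SimpleRequest.dict() always produces that shape because config and stream have defaults. The minimal accepted payload, for illustration:

payload = {
    "text": "Once upon a time",  # required, non-empty
    "config": {},                # optional generation params: max_new_tokens, temperature, top_p, top_k
    "stream": False,
}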
90 changes: 90 additions & 0 deletions tests/inference/test_simple_protocal.py
Review comment (Contributor): filename: protocal => protocol

@@ -0,0 +1,90 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import subprocess
import pytest
import os
from basic_set import start_serve
import requests
from llm_on_ray.inference.api_simple_backend.simple_protocol import (
SimpleRequest,
SimpleModelResponse,
)


executed_models = []


# Parametrize the test function with different combinations of parameters
# TODO: more models and combinations will be added and tested.
@pytest.mark.parametrize(
"prompt,streaming_response,max_new_tokens,temperature,top_p, top_k",
[
(
prompt,
streaming_response,
max_new_tokens,
temperature,
top_p,
top_k,
)
for prompt in ["Once upon a time", ""]
for streaming_response in [None, True, "error"]
for max_new_tokens in [None, 128, "error"]
for temperature in [None]
for top_p in [None]
for top_k in [None]
],
)
def test_script(prompt, streaming_response, max_new_tokens, temperature, top_p, top_k):
global executed_models

    # Check whether start_serve has already been run for this model name
if "gpt2" not in executed_models:
start_serve("gpt2", simple=True)
        # Record that start_serve has been run for this model name
executed_models.append("gpt2")
config = {}
if max_new_tokens:
config["max_new_tokens"] = max_new_tokens
if temperature:
config["temperature"] = temperature
if top_p:
config["top_p"] = top_p
if top_k:
config["top_k"] = top_k

try:
sample_input = SimpleRequest(text=prompt, config=config, stream=streaming_response)
except ValueError as e:
print(e)
return
outputs = requests.post(
"http://localhost:8000/gpt2",
proxies={"http": None, "https": None}, # type: ignore
json=sample_input.dict(),
stream=streaming_response,
)

outputs.raise_for_status()

simple_response = SimpleModelResponse.from_requests_response(outputs)
if streaming_response:
for output in simple_response.iter_content(chunk_size=1, decode_unicode=True):
print(output, end="", flush=True)
print()
else:
print(simple_response.text, flush=True)