Commit

update
minmingzhu committed Apr 29, 2024
1 parent 60d9f15 commit 8698a17
Showing 8 changed files with 58 additions and 177 deletions.
7 changes: 6 additions & 1 deletion examples/inference/api_server_simple/query_single.py
@@ -55,7 +55,12 @@
)

args = parser.parse_args()
prompt = "Once upon a time,"
# prompt = "Once upon a time,"
prompt = [
{"role": "user", "content": "Which is bigger, the moon or the sun?"},
]


config: Dict[str, Union[int, float]] = {}
if args.max_new_tokens:
config["max_new_tokens"] = int(args.max_new_tokens)
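With this change the example prompt is an OpenAI-style message list rather than a plain string. Below is a minimal sketch of how such a list might be sent to the simple API server; the endpoint URL and the payload keys ("text", "config", "stream") are illustrative assumptions, since the request-building code is not part of this hunk.

# Sketch: posting the new message-list prompt to the simple API server.
# The endpoint URL and the payload keys are assumptions for illustration;
# the hunk above does not show the request code.
import requests

prompt = [
    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
]
config = {"max_new_tokens": 128}

response = requests.post(
    "http://localhost:8000/llama-2-7b-chat-hf",  # placeholder endpoint
    json={"text": prompt, "config": config, "stream": False},
    timeout=60,
)
print(response.text)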
2 changes: 1 addition & 1 deletion llm_on_ray/finetune/finetune_config.py
@@ -64,7 +64,7 @@ class General(BaseModel):
enable_gradient_checkpointing: bool = False
chat_template: Optional[str] = None
default_chat_template: str = (
"{{ bos_token }}"
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
"{% if messages[0]['role'] == 'system' %}"
"{{ raise_exception('System role not supported') }}"
"{% endif %}"
122 changes: 32 additions & 90 deletions llm_on_ray/inference/chat_template_process.py
@@ -1,20 +1,4 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Union

from llm_on_ray.inference.api_openai_backend.openai_protocol import ChatMessage


@@ -23,57 +7,31 @@ def __init__(self, predictor) -> None:
self.predictor = predictor

def get_prompt(self, input: List, is_mllm=False):
"""Generate response based on input."""
if self.predictor.infer_conf.model_description.chat_template is not None:
self.predictor.tokenizer.chat_template = (
self.predictor.infer_conf.model_description.chat_template
self.predictor.tokenizer.chat_template = (
self.predictor.infer_conf.model_description.chat_template
or self.predictor.tokenizer.chat_template
or self.predictor.infer_conf.model_description.default_chat_template
)

if isinstance(input, list) and input and isinstance(input[0], (ChatMessage, dict)):
messages = (
[dict(chat_message) for chat_message in input]
if isinstance(input[0], ChatMessage)
else input
)
elif self.predictor.tokenizer.chat_template is None:
self.predictor.tokenizer.chat_template = (
self.predictor.infer_conf.model_description.default_chat_template
prompt = self.predictor.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)

if is_mllm:
if isinstance(input, List):
if isinstance(input, list) and input and isinstance(input[0], ChatMessage):
messages = []
for chat_message in input:
message = {
"role": chat_message.role,
"content": chat_message.content,
}
messages.append(message)
texts, images = self._extract_messages(messages)
elif isinstance(input, list) and input and isinstance(input[0], dict):
texts, images = self._extract_messages(input)
elif isinstance(input, list) and input and isinstance(input[0], list):
texts, images = [self._extract_messages(p) for p in input]

if is_mllm:
texts, images = self._extract_messages(messages)
image = self._prepare_image(images)
prompt = self.predictor.tokenizer.apply_chat_template(texts, tokenize=False)
return prompt, image
else:
if isinstance(input, list) and input and isinstance(input[0], dict):
prompt = self.predictor.tokenizer.apply_chat_template(input, tokenize=False)
elif isinstance(input, list) and input and isinstance(input[0], list):
prompt = [
self.predictor.tokenizer.apply_chat_template(t, tokenize=False) for t in input
]
elif isinstance(input, list) and input and isinstance(input[0], ChatMessage):
messages = []
for chat_message in input:
message = {"role": chat_message.role, "content": chat_message.content}
messages.append(message)
prompt = self.predictor.tokenizer.apply_chat_template(messages, tokenize=False)
elif isinstance(input, list) and input and isinstance(input[0], str):
prompt = input
elif isinstance(input, str):
prompt = input
else:
raise TypeError(
f"Unsupported type {type(input)} for text. Expected dict or list of dicts."
prompt = self.predictor.tokenizer.apply_chat_template(
texts, add_generation_prompt=True, tokenize=False
)
return prompt
return prompt, image
return prompt

raise TypeError(f"Unsupported type {type(input)} for text. Expected dict or list of dicts.")

def _extract_messages(self, messages):
texts, images = [], []
@@ -88,39 +46,23 @@ def _extract_messages(self, messages):
return texts, images

def _prepare_image(self, messages: list):
"""Prepare image from history messages."""
from PIL import Image
import requests
from io import BytesIO
import base64
import re

# prepare images
images: List = []
if isinstance(messages[0], List):
for i in range(len(messages)):
for msg in messages[i]:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images[i].append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images[i].append(Image.open(requests.get(content["url"], stream=True).raw))
elif isinstance(messages[0], dict):
for msg in messages:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images.append(Image.open(requests.get(content["url"], stream=True).raw))
for msg in messages:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images.append(Image.open(requests.get(content["url"], stream=True).raw))

return images
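Without the diff's +/- colouring, the interleaved old and new lines above are hard to follow. The stand-alone sketch below approximates the consolidated flow the reworked get_prompt appears to implement: choose the first available chat template, normalize ChatMessage objects to dicts, and render with add_generation_prompt=True. The helper names (pick_chat_template, render_prompt) and the plain-Jinja rendering are illustrative, not code from the repository.

# Illustrative sketch only: pick_chat_template/render_prompt are hypothetical
# helpers, and plain Jinja stands in for tokenizer.apply_chat_template.
from dataclasses import asdict, dataclass
from typing import List, Optional, Union

from jinja2 import Template


@dataclass
class ChatMessage:
    role: str
    content: str


def pick_chat_template(
    model_template: Optional[str],
    tokenizer_template: Optional[str],
    default_template: str,
) -> str:
    # Same precedence the diff establishes: model config, then tokenizer,
    # then the built-in default_chat_template.
    return model_template or tokenizer_template or default_template


def render_prompt(messages: List[Union[ChatMessage, dict]], template: str) -> str:
    if not (isinstance(messages, list) and messages):
        raise TypeError(f"Unsupported type {type(messages)} for text.")
    # Normalize ChatMessage objects into plain role/content dicts.
    normalized = [asdict(m) if isinstance(m, ChatMessage) else m for m in messages]
    return Template(template).render(messages=normalized, add_generation_prompt=True)


template = pick_chat_template(
    None,
    None,
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}",
)
print(render_prompt([ChatMessage("user", "Hello")], template))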
16 changes: 9 additions & 7 deletions llm_on_ray/inference/inference_config.py
@@ -115,20 +115,22 @@ class ModelDescription(BaseModel):
chat_model_with_image: bool = False
chat_template: Union[str, None] = None
default_chat_template: str = (
"{{ bos_token }}"
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
"{% if messages[0]['role'] == 'system' %}"
"{{ raise_exception('System role not supported') }}"
"{% endif %}"
"{% for message in messages %}"
"{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}"
"{% else %}{% set loop_messages = messages %}"
"{% set system_message = false %}{% endif %}"
"{% for message in loop_messages %}"
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
"{% endif %}"
"{% if message['role'] == 'user' %}"
"{{ '### Instruction: ' + message['content'] + eos_token }}"
"{{ '### Instruction: ' + message['content'].strip() }}"
"{% elif message['role'] == 'assistant' %}"
"{{ '### Response:' + message['content'] + eos_token }}"
"{{ '### Response:' + message['content'].strip() }}"
"{% endif %}{% endfor %}"
"{{'### End \n'}}"
"{% if add_generation_prompt %}{{'### Response:\n'}}{% endif %}"
)

@validator("quantization_type")
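The default template shifts from a bos_token-prefixed format to an Alpaca-style "### Instruction:" / "### Response:" layout. Below is a hedged sketch of applying a template in that style through a tokenizer; the template string is a simplified stand-in, not the exact default_chat_template from this diff, and gpt2 is used only as a convenient tokenizer.

# Simplified stand-in template; not the exact default_chat_template above.
from transformers import AutoTokenizer

alpaca_style = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '### Instruction: ' + message['content'].strip() }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '### Response:' + message['content'].strip() }}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '### Response:' }}{% endif %}"
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = alpaca_style

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}],
    add_generation_prompt=True,
    tokenize=False,
)
print(prompt)
# -> "Below is an instruction ... ### Instruction: Which is bigger, the moon
#     or the sun?### Response:"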
1 change: 1 addition & 0 deletions llm_on_ray/inference/models/gemma-2b.yaml
@@ -15,3 +15,4 @@ model_description:
tokenizer_name_or_path: google/gemma-2b
config:
use_auth_token: ' '
chat_template: "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
1 change: 1 addition & 0 deletions llm_on_ray/inference/models/gpt2.yaml
@@ -14,3 +14,4 @@ model_description:
model_id_or_path: gpt2
tokenizer_name_or_path: gpt2
gpt_base_model: true
chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() + eos_token }}{% endif %}{% endfor %}"
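Both YAML entries now carry explicit chat_template strings. A quick way to sanity-check such a template outside the serving stack is to render it with plain Jinja, as sketched below. This assumes chat_template sits under model_description as the hunk headers indicate; raise_exception is supplied manually because it is not a built-in Jinja helper, and the token values are placeholders.

# Render a YAML-supplied chat template with plain Jinja for a quick check.
import yaml
from jinja2 import Environment


def raise_exception(message):
    raise ValueError(message)


with open("llm_on_ray/inference/models/gpt2.yaml") as f:
    model_conf = yaml.safe_load(f)

env = Environment()
env.globals["raise_exception"] = raise_exception
template = env.from_string(model_conf["model_description"]["chat_template"])

print(
    template.render(
        messages=[
            {"role": "user", "content": "Hi there"},
            {"role": "assistant", "content": "Hello!"},
        ],
        bos_token="<|endoftext|>",  # placeholder token values
        eos_token="<|endoftext|>",
        add_generation_prompt=True,
    )
)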
82 changes: 6 additions & 76 deletions llm_on_ray/inference/predictor_deployment.py
@@ -311,7 +311,7 @@ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=None
Raises:
HTTPException: If the input prompt format is invalid or not supported.
"""

print("preprocess_prompts")
if isinstance(input, str):
return input
elif isinstance(input, list):
@@ -344,31 +344,6 @@ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=None
else:
prompt = self.process_tool.get_prompt(input)
return prompt
else:
if isinstance(input, list) and input and isinstance(input[0], dict):
prompt = self.predictor.tokenizer.apply_chat_template(input, tokenize=False)
elif isinstance(input, list) and input and isinstance(input[0], list):
prompt = [
self.predictor.tokenizer.apply_chat_template(t, tokenize=False)
for t in input
]
elif isinstance(input, list) and input and isinstance(input[0], ChatMessage):
messages = []
for chat_message in input:
message = {"role": chat_message.role, "content": chat_message.content}
messages.append(message)
prompt = self.predictor.tokenizer.apply_chat_template(
messages, tokenize=False
)
elif isinstance(input, list) and input and isinstance(input[0], str):
prompt = input
elif isinstance(input, str):
prompt = input
else:
raise TypeError(
f"Unsupported type {type(input)} for text. Expected dict or list of dicts."
)
return prompt
elif prompt_format == PromptFormat.PROMPTS_FORMAT:
raise HTTPException(400, "Invalid prompt format.")
return input
Expand Down Expand Up @@ -414,63 +389,18 @@ async def openai_call(
tool_choice=None,
):
self.use_openai = True
print("openai_call")
print(input)
print(type(input))

# return prompt or list of prompts preprocessed
prompts = self.preprocess_prompts(input, tools, tool_choice)
print(prompts)
print(type(prompts))

# Handle streaming response
if streaming_response:
async for result in self.handle_streaming(prompts, config):
yield result
else:
yield await self.handle_non_streaming(prompts, config)

def _extract_messages(self, messages):
texts, images = [], []
for message in messages:
if message["role"] == "user" and isinstance(message["content"], list):
texts.append({"role": "user", "content": message["content"][0]["text"]})
images.append(
{"role": "user", "content": message["content"][1]["image_url"]["url"]}
)
else:
texts.append(message)
return texts, images

def _prepare_image(self, messages: list):
"""Prepare image from history messages."""
from PIL import Image
import requests
from io import BytesIO
import base64
import re

# prepare images
images: List = []
if isinstance(messages[0], List):
for i in range(len(messages)):
for msg in messages[i]:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images[i].append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images[i].append(Image.open(requests.get(content["url"], stream=True).raw))
elif isinstance(messages[0], dict):
for msg in messages:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images.append(Image.open(requests.get(content["url"], stream=True).raw))

return images
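With the duplicated helpers removed, openai_call is left with the streaming/non-streaming branch shown above. The following is a small self-contained sketch of that async-generator pattern; fake_handle_streaming and fake_handle_non_streaming are stand-ins, not the deployment's actual methods.

# Minimal async-generator sketch of the streaming vs. non-streaming branch;
# the handler functions below are illustrative stand-ins.
import asyncio
from typing import AsyncGenerator


async def fake_handle_streaming(prompt: str) -> AsyncGenerator[str, None]:
    for token in prompt.split():
        yield token + " "


async def fake_handle_non_streaming(prompt: str) -> str:
    return prompt


async def openai_call(prompt: str, streaming_response: bool):
    if streaming_response:
        # Stream partial results as they are produced.
        async for chunk in fake_handle_streaming(prompt):
            yield chunk
    else:
        # Produce a single complete result.
        yield await fake_handle_non_streaming(prompt)


async def main():
    async for piece in openai_call("hello from the sketch", streaming_response=True):
        print(piece, end="")
    print()


asyncio.run(main())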
4 changes: 2 additions & 2 deletions llm_on_ray/inference/utils.py
@@ -162,13 +162,13 @@ def is_cpu_without_ipex(infer_conf: InferenceConfig) -> bool:
return (not infer_conf.ipex.enabled) and infer_conf.device == DEVICE_CPU


def get_prompt_format(input: Union[List[str], List[dict], List[List[dict]], List[ChatMessage]]):
def get_prompt_format(input: Union[List[str], List[dict], List[ChatMessage]]):
chat_format = True
prompts_format = True
for item in input:
if isinstance(item, str):
chat_format = False
elif isinstance(item, dict) or isinstance(item, ChatMessage) or isinstance(item, list):
elif isinstance(item, dict) or isinstance(item, ChatMessage):
prompts_format = False
else:
chat_format = False
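get_prompt_format no longer treats a list of lists (batched conversations) as valid chat input. Below is a stand-alone sketch of the narrowed classification; the enum and function are stand-ins for the project's PromptFormat and get_prompt_format, and the handling after the truncated else branch is an assumption.

# Stand-in sketch; the PromptFormat values and the tail of the else branch
# are assumptions, since the diff is truncated after "chat_format = False".
from enum import Enum
from typing import List, Union


class PromptFormat(Enum):
    CHAT_FORMAT = "chat"
    PROMPTS_FORMAT = "prompts"
    INVALID_FORMAT = "invalid"


def classify_prompt_format(input: List[Union[str, dict]]) -> PromptFormat:
    chat_format = True
    prompts_format = True
    for item in input:
        if isinstance(item, str):
            chat_format = False      # plain strings are prompt-style
        elif isinstance(item, dict):
            prompts_format = False   # role/content dicts are chat-style
        else:
            # Nested lists and anything else now fall through here.
            chat_format = False
            prompts_format = False
    if chat_format:
        return PromptFormat.CHAT_FORMAT
    if prompts_format:
        return PromptFormat.PROMPTS_FORMAT
    return PromptFormat.INVALID_FORMAT


print(classify_prompt_format([{"role": "user", "content": "Hello"}]))  # CHAT_FORMAT
print(classify_prompt_format(["Once upon a time,"]))                   # PROMPTS_FORMAT
print(classify_prompt_format([[{"role": "user", "content": "Hi"}]]))   # INVALID_FORMAT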
