TorchServe quickstart chatbot example (#3003)
* TorchServe quickstart chatbot example
* Added more details in Readme
* lint failure
* code cleanup
* review comments

Co-authored-by: Mark Saroufim <[email protected]>
Showing 9 changed files with 565 additions and 0 deletions.
Dockerfile
@@ -0,0 +1,26 @@
ARG BASE_IMAGE=pytorch/torchserve:latest-gpu

FROM $BASE_IMAGE as server
ARG BASE_IMAGE
ARG EXAMPLE_DIR
ARG MODEL_NAME
ARG HUGGINGFACE_TOKEN

USER root

ENV MODEL_NAME=$MODEL_NAME

RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \
    apt-get update && \
    apt-get install libopenmpi-dev git -y

COPY $EXAMPLE_DIR/requirements.txt /home/model-server/chat_bot/requirements.txt
RUN pip install -r /home/model-server/chat_bot/requirements.txt && huggingface-cli login --token $HUGGINGFACE_TOKEN

COPY $EXAMPLE_DIR /home/model-server/chat_bot
COPY $EXAMPLE_DIR/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
COPY $EXAMPLE_DIR/config.properties /home/model-server/config.properties

WORKDIR /home/model-server/chat_bot
RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \
    && chown -R model-server /home/model-server
Build script
@@ -0,0 +1,28 @@
#!/bin/bash

# Check if there are enough arguments
if [ "$#" -eq 0 ] || [ "$#" -gt 1 ]; then
    echo "Usage: $0 <HF Model>"
    exit 1
fi

MODEL_NAME=$(echo "$1" | sed 's/\//---/g')
echo "Model: " $MODEL_NAME

BASE_IMAGE="pytorch/torchserve:latest-cpu"

DOCKER_TAG="pytorch/torchserve:${MODEL_NAME}"

# Get relative path of example dir
EXAMPLE_DIR=$(dirname "$(readlink -f "$0")")
ROOT_DIR=${EXAMPLE_DIR}/../../../../..
ROOT_DIR=$(realpath "$ROOT_DIR")
EXAMPLE_DIR=$(echo "$EXAMPLE_DIR" | sed "s|$ROOT_DIR|./|")

# Build docker image for the application
DOCKER_BUILDKIT=1 docker buildx build --platform=linux/amd64 --file ${EXAMPLE_DIR}/Dockerfile --build-arg BASE_IMAGE="${BASE_IMAGE}" --build-arg EXAMPLE_DIR="${EXAMPLE_DIR}" --build-arg MODEL_NAME="${MODEL_NAME}" --build-arg HUGGINGFACE_TOKEN -t "${DOCKER_TAG}" .

echo "Run the following command to start the chat bot"
echo ""
echo docker run --rm -it --platform linux/amd64 -p 127.0.0.1:8080:8080 -p 127.0.0.1:8081:8081 -p 127.0.0.1:8082:8082 -p 127.0.0.1:8084:8084 -p 127.0.0.1:8085:8085 -v $(pwd)/model_store_1:/home/model-server/model-store $DOCKER_TAG
echo ""
client_app.py
@@ -0,0 +1,118 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor

import requests
import streamlit as st

MODEL_NAME = os.environ["MODEL_NAME"]

# App title
st.set_page_config(page_title="TorchServe Chatbot")

with st.sidebar:
    st.title("TorchServe Chatbot")

    st.session_state.model_loaded = False
    try:
        res = requests.get(url="http://localhost:8080/ping")
        res = requests.get(url=f"http://localhost:8081/models/{MODEL_NAME}")
        status = "NOT READY"
        if res.status_code == 200:
            status = json.loads(res.text)[0]["workers"][0]["status"]

        if status == "READY":
            st.session_state.model_loaded = True
            st.success("Proceed to entering your prompt message!", icon="👉")
        else:
            st.warning("Model not loaded in TorchServe", icon="⚠️")

    except requests.ConnectionError:
        st.warning("TorchServe is not up. Try again", icon="⚠️")

    if st.session_state.model_loaded:
        st.success(f"Model loaded: {MODEL_NAME}!", icon="👉")

    st.subheader("Model parameters")
    temperature = st.sidebar.slider(
        "temperature", min_value=0.1, max_value=1.0, value=0.5, step=0.1
    )
    top_p = st.sidebar.slider(
        "top_p", min_value=0.1, max_value=1.0, value=0.5, step=0.1
    )
    max_new_tokens = st.sidebar.slider(
        "max_new_tokens", min_value=48, max_value=512, value=50, step=4
    )
    concurrent_requests = st.sidebar.select_slider(
        "concurrent_requests", options=[2**j for j in range(0, 8)]
    )

# Store LLM generated responses
if "messages" not in st.session_state.keys():
    st.session_state.messages = [
        {"role": "assistant", "content": "How may I assist you today?"}
    ]

# Display or clear chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


def clear_chat_history():
    st.session_state.messages = [
        {"role": "assistant", "content": "How may I assist you today?"}
    ]


st.sidebar.button("Clear Chat History", on_click=clear_chat_history)


def generate_model_response(prompt_input, executor):
    string_dialogue = (
        "Question: What are the names of the planets in the solar system? Answer: "
    )
    headers = {"Content-type": "application/json", "Accept": "text/plain"}
    url = f"http://127.0.0.1:8080/predictions/{MODEL_NAME}"
    data = json.dumps(
        {
            "prompt": prompt_input,
            "params": {
                "max_new_tokens": max_new_tokens,
                "top_p": top_p,
                "temperature": temperature,
            },
        }
    )
    res = [
        executor.submit(requests.post, url=url, data=data, headers=headers, stream=True)
        for i in range(concurrent_requests)
    ]

    return res, max_new_tokens


# User-provided prompt
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if last message is not from assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            with ThreadPoolExecutor() as executor:
                futures, max_tokens = generate_model_response(prompt, executor)
                placeholder = st.empty()
                full_response = ""
                count = 0
                for future in futures:
                    response = future.result()
                    for chunk in response.iter_content(chunk_size=None):
                        if chunk:
                            data = chunk.decode("utf-8")
                            full_response += data
                            placeholder.markdown(full_response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)
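For quick testing outside the Streamlit UI, the same streaming predictions endpoint can be hit directly with requests. A minimal sketch, assuming the container is running with the port mappings from the build script's docker run command and that the model is registered under the "---"-mangled name:

import json

import requests

# Assumption: model name as registered by the entrypoint (HF id with "/" replaced by "---").
model_name = "meta-llama---Llama-2-7b-chat-hf"
url = f"http://127.0.0.1:8080/predictions/{model_name}"

payload = json.dumps(
    {
        "prompt": "What are the names of the planets in the solar system?",
        "params": {"max_new_tokens": 50, "top_p": 0.5, "temperature": 0.5},
    }
)

# stream=True lets us consume the intermediate responses the handler sends back.
with requests.post(
    url, data=payload, headers={"Content-type": "application/json"}, stream=True
) as resp:
    for chunk in resp.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8"), end="", flush=True)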
config.properties
@@ -0,0 +1,9 @@
metrics_mode=prometheus
model_metrics_auto_detect=true
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
number_of_netty_threads=32
job_queue_size=1000
model_store=/home/model-server/model-store
workflow_store=/home/model-server/wf-store
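The three addresses correspond to TorchServe's inference, management, and metrics APIs, bound to 0.0.0.0 so they are reachable through the published container ports. A small sketch of probing them from the host once the container is up (these are standard TorchServe endpoints, not specific to this example):

import requests

# Inference API health check (port 8080)
print(requests.get("http://127.0.0.1:8080/ping").json())

# Management API: list registered models (port 8081)
print(requests.get("http://127.0.0.1:8081/models").json())

# Metrics API: Prometheus-formatted metrics, since metrics_mode=prometheus (port 8082)
print(requests.get("http://127.0.0.1:8082/metrics").text)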
dockerd-entrypoint.sh
@@ -0,0 +1,81 @@
#!/bin/bash
set -e

export LLAMA2_Q4_MODEL=/home/model-server/model-store/$MODEL_NAME/model/ggml-model-q4_0.gguf


create_model_cfg_yaml() {
    # Define the YAML content with a placeholder for the model name
    yaml_content="# TorchServe frontend parameters\nminWorkers: 1\nmaxWorkers: 1\nresponseTimeout: 1200\n#deviceType: \"gpu\"\n#deviceIds: [0,1]\n#torchrun:\n# nproc-per-node: 1\n\nhandler:\n model_name: \"${2}\"\n manual_seed: 40"

    # Create the YAML file with the specified model name
    echo -e "$yaml_content" > "model-config-${1}.yaml"
}

create_model_archive() {
    MODEL_NAME=$1
    MODEL_CFG=$2
    echo "Create model archive for ${MODEL_NAME} if it doesn't already exist"
    if [ -d "/home/model-server/model-store/$MODEL_NAME" ]; then
        echo "Model archive for $MODEL_NAME exists."
    fi
    if [ -d "/home/model-server/model-store/$MODEL_NAME/model" ]; then
        echo "Model already downloaded"
        mv /home/model-server/model-store/$MODEL_NAME/model /home/model-server/model-store/
    else
        echo "Model needs to be downloaded"
    fi
    torch-model-archiver --model-name "$MODEL_NAME" --version 1.0 --handler llama_cpp_handler.py --config-file $MODEL_CFG -r requirements.txt --archive-format no-archive --export-path /home/model-server/model-store -f
    if [ -d "/home/model-server/model-store/model" ]; then
        mv /home/model-server/model-store/model /home/model-server/model-store/$MODEL_NAME/
    fi
}

download_model() {
    MODEL_NAME=$1
    HF_MODEL_NAME=$2
    if [ -d "/home/model-server/model-store/$MODEL_NAME/model" ]; then
        echo "Model $HF_MODEL_NAME already downloaded"
    else
        echo "Downloading model $HF_MODEL_NAME"
        python Download_model.py --model_path /home/model-server/model-store/$MODEL_NAME/model --model_name $HF_MODEL_NAME
    fi
}

quantize_model() {
    if [ ! -f "$LLAMA2_Q4_MODEL" ]; then
        tmp_model_name=$(echo "$MODEL_NAME" | sed 's/---/--/g')
        directory_path=/home/model-server/model-store/$MODEL_NAME/model/models--$tmp_model_name/snapshots/
        HF_MODEL_SNAPSHOT=$(find $directory_path -type d -mindepth 1)
        echo "Cleaning up previous build of llama-cpp"
        git clone https://github.com/ggerganov/llama.cpp.git build
        cd build
        make
        python -m pip install -r requirements.txt

        echo "Convert the 7B model to ggml FP16 format"
        python convert.py $HF_MODEL_SNAPSHOT --outfile ggml-model-f16.gguf

        echo "Quantize the model to 4-bits (using q4_0 method)"
        ./quantize ggml-model-f16.gguf $LLAMA2_Q4_MODEL q4_0

        cd ..
        echo "Saved quantized model weights to $LLAMA2_Q4_MODEL"
    fi
}

HF_MODEL_NAME=$(echo "$MODEL_NAME" | sed 's/---/\//g')
if [[ "$1" = "serve" ]]; then
    shift 1
    create_model_cfg_yaml $MODEL_NAME $HF_MODEL_NAME
    create_model_archive $MODEL_NAME "model-config-$MODEL_NAME.yaml"
    download_model $MODEL_NAME $HF_MODEL_NAME
    quantize_model
    streamlit run torchserve_server_app.py --server.port 8084 &
    streamlit run client_app.py --server.port 8085
else
    eval "$@"
fi

# prevent docker exit
tail -f /dev/null
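For readability, the escaped string written by create_model_cfg_yaml expands (after echo -e) to roughly the following model-config-<MODEL_NAME>.yaml, with the handler's model_name filled in from the Hugging Face model id ("<HF model id>" below is a placeholder):

# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
responseTimeout: 1200
#deviceType: "gpu"
#deviceIds: [0,1]
#torchrun:
# nproc-per-node: 1

handler:
 model_name: "<HF model id>"
 manual_seed: 40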
llama_cpp_handler.py
@@ -0,0 +1,67 @@
import logging
import os
from abc import ABC

import torch
from llama_cpp import Llama

from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LlamaCppHandler(BaseHandler, ABC):
    def __init__(self):
        super(LlamaCppHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Loads the quantized llama.cpp (GGUF) model with llama-cpp-python.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_path = os.environ["LLAMA2_Q4_MODEL"]
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        self.model = Llama(model_path=model_path)
        logger.info(f"Loaded {model_name} model successfully")

    def preprocess(self, data):
        assert (
            len(data) == 1
        ), "llama-cpp-python is currently only supported with batch_size=1"
        for row in data:
            item = row.get("body")
        return item

    def inference(self, data):
        params = data["params"]
        tokens = self.model.tokenize(bytes(data["prompt"], "utf-8"))
        generation_kwargs = dict(
            tokens=tokens,
            temp=params["temperature"],
            top_p=params["top_p"],
        )
        count = 0
        for token in self.model.generate(**generation_kwargs):
            if count >= params["max_new_tokens"]:
                break

            count += 1
            new_text = self.model.detokenize([token])
            send_intermediate_predict_response(
                [new_text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )
        return [""]

    def postprocess(self, output):
        return output
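The handler streams output by pairing llama-cpp-python's generate loop with send_intermediate_predict_response. Outside TorchServe, the same tokenize/generate/detokenize calls can be exercised directly; a minimal sketch, assuming LLAMA2_Q4_MODEL points at the quantized GGUF file produced by the entrypoint:

import os

from llama_cpp import Llama

# Assumption: path exported by dockerd-entrypoint.sh to the q4_0 GGUF file.
model = Llama(model_path=os.environ["LLAMA2_Q4_MODEL"])

prompt = "What are the names of the planets in the solar system?"
tokens = model.tokenize(bytes(prompt, "utf-8"))

count = 0
for token in model.generate(tokens, temp=0.5, top_p=0.5):
    if count >= 50:  # cap output length, mirroring max_new_tokens
        break
    count += 1
    # detokenize returns bytes for the given token ids
    print(model.detokenize([token]).decode("utf-8", errors="ignore"), end="", flush=True)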
requirements.txt
@@ -0,0 +1,4 @@
transformers
llama-cpp-python
streamlit>=1.26.0
requests_futures