Add vllm Predictor (intel#20)
* add vllm_predictor

* add tests skeleton

* add tests skeleton

* add pytest.ini

* wip

* complete, debug wip

* nit

* nit

* nit

* complete generate supporting str and List[str]

* add model

* add streaming

* remove tests

* Add install-vllm-cpu script

* nit

Signed-off-by: Wu, Xiaochang <[email protected]>

* nit

Signed-off-by: Wu, Xiaochang <[email protected]>

* nit

* fix package inference

* update install script and add doc

* nit

* nit

* nit

* add dtype support

* nit

* nit

* nit

* add ci

* nit

* nit

* add libpthread-stubs0-dev

* fix install-vllm-cpu

* fix

* revert inference.inference_config

* debug ci

* debug ci

* debug ci

* debug ci

* debug ci

* debug ci

* debug ci

* debug ci

* update

---------

Signed-off-by: Wu, Xiaochang <[email protected]>
xwu99 authored Jan 18, 2024
1 parent 906e597 commit e6494c0
Showing 15 changed files with 309 additions and 28 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
@@ -0,0 +1,27 @@
port: 8000
name: llama-2-7b-chat-hf-vllm
route_prefix: /llama-2-7b-chat-hf-vllm
cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
vllm:
  enabled: true
  precision: fp32
workers_per_group: 2
device: "cpu"
ipex:
  enabled: false
  precision: bf16
model_description:
  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
  chat_processor: ChatModelLLama
  prompt:
    intro: ''
    human_id: '[INST] {msg} [/INST]
'
    bot_id: ''
    stop_words: []
  config:
    use_auth_token: ''
27 changes: 18 additions & 9 deletions .github/workflows/workflow_inference.yml
@@ -34,7 +34,7 @@ jobs:
name: inference test
strategy:
matrix:
model: [ gpt-j-6b, gpt2, bloom-560m, opt-125m, mpt-7b, mistral-7b-v0.1, mpt-7b-bigdl, neural-chat-7b-v3-1, CodeLlama-7b-hf, falcon-7b ]
model: [ gpt-j-6b, gpt2, bloom-560m, opt-125m, mpt-7b, mistral-7b-v0.1, mpt-7b-bigdl, neural-chat-7b-v3-1, CodeLlama-7b-hf, falcon-7b, llama-2-7b-chat-hf-vllm ]
isPR:
- ${{inputs.ci_type == 'pr'}}

@@ -45,6 +45,7 @@ jobs:
- { model: "gpt-j-6b"}
- { model: "mistral-7b-v0.1"}
- { model: "mpt-7b-bigdl"}
- { model: "llama-2-7b-chat-hf-vllm"}
- dtuner_model: nathan0/mpt-7b-deltatuner-model
model: mpt-7b

@@ -64,13 +65,15 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v2

- name: Determine Target
id: "target"
run: |
target="inference"
if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then
target="${target}_bigdl_cpu"
elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then
target="${target}_vllm"
fi
echo "target is ${target}"
echo "target=$target" >> $GITHUB_OUTPUT
@@ -79,6 +82,8 @@ jobs:
run: |
if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then
DF_SUFFIX=".bigdl-cpu"
elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then
DF_SUFFIX=".vllm"
else
DF_SUFFIX=".cpu_and_deepspeed"
fi
@@ -106,12 +111,16 @@ jobs:
TARGET=${{steps.target.outputs.target}}
if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then
docker exec "${TARGET}" bash -c "python inference/serve.py --config_file inference/models/bigdl/mpt-7b-bigdl.yaml --simple"
elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then
docker exec "${TARGET}" bash -c "python inference/serve.py --config_file .github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml --simple"
else
docker exec "${TARGET}" bash -c "python inference/serve.py --simple --models ${{ matrix.model }}"
fi
echo Non-streaming query:
docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }}"
echo Streaming query:
docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }} --streaming_response"
- name: Run Inference Test with Deltatuner
if: ${{ matrix.dtuner_model }}
run: |
@@ -125,7 +134,7 @@ jobs:
TARGET=${{steps.target.outputs.target}}
if [[ ${{ matrix.model }} =~ ^(gpt2|falcon-7b|mpt-7b.*)$ ]]; then
echo ${{ matrix.model }} is not supported!
else
elif [[ ! ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then
docker exec "${TARGET}" bash -c "python .github/workflows/config/update_inference_config.py --config_file inference/models/\"${{ matrix.model }}\".yaml --output_file \"${{ matrix.model }}\".yaml.deepspeed --deepspeed"
docker exec "${TARGET}" bash -c "python inference/serve.py --config_file \"${{ matrix.model }}\".yaml.deepspeed --simple"
docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }}"
@@ -143,16 +152,16 @@ jobs:
docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }}"
docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }} --streaming_response"
fi
- name: Run Inference Test with REST API
run: |
TARGET=${{steps.target.outputs.target}}
if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then
docker exec "${TARGET}" bash -c "python inference/serve.py --config_file inference/models/bigdl/mpt-7b-bigdl.yaml"
else
elif [[ ! ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then
docker exec "${TARGET}" bash -c "python inference/serve.py --models ${{ matrix.model }}"
docker exec "${TARGET}" bash -c "python examples/inference/api_server_openai/query_http_requests.py --model_name ${{ matrix.model }}"
fi
docker exec "${TARGET}" bash -c "python examples/inference/api_server_openai/query_http_requests.py --model_name ${{ matrix.model }}"
- name: Stop Ray
run: |
@@ -161,7 +170,7 @@ jobs:
if [[ ! -z "$cid" ]]; then
docker exec "${TARGET}" bash -c "ray stop"
fi
- name: Stop Container
if: success() || failure()
run: |
@@ -173,4 +182,4 @@ jobs:
run: echo "to be continued"



2 changes: 1 addition & 1 deletion .github/workflows/workflow_orders_on_pr.yml
@@ -20,7 +20,7 @@ jobs:

call-lint:
uses: ./.github/workflows/workflow_lint.yml

call-tests:
needs: call-lint
uses: ./.github/workflows/workflow_tests.yml
42 changes: 42 additions & 0 deletions dev/docker/Dockerfile.vllm
@@ -0,0 +1,42 @@
# syntax=docker/dockerfile:1
FROM ubuntu:22.04

ENV LANG C.UTF-8

WORKDIR /root/llm-on-ray

RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \
&& apt-get install -y build-essential cmake wget curl git vim htop ssh net-tools \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

ENV CONDA_DIR /opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
/bin/bash ~/miniconda.sh -b -p /opt/conda
ENV PATH $CONDA_DIR/bin:$PATH

# setup env
SHELL ["/bin/bash", "--login", "-c"]

RUN --mount=type=cache,target=/opt/conda/pkgs conda init bash && \
unset -f conda && \
export PATH=$CONDA_DIR/bin/:${PATH} && \
conda config --add channels intel && \
conda install -y -c conda-forge python==3.9 gxx=12.3 gxx_linux-64=12.3

COPY ./pyproject.toml .
COPY ./dev/scripts/install-vllm-cpu.sh .

RUN mkdir ./finetune && mkdir ./inference

RUN --mount=type=cache,target=/root/.cache/pip pip install -e .[cpu] -f https://developer.intel.com/ipex-whl-stable-cpu \
-f https://download.pytorch.org/whl/torch_stable.html

# Install vllm-cpu
# Activate base first for loading g++ envs ($CONDA_PREFIX/etc/conda/activate.d/*)
RUN --mount=type=cache,target=/root/.cache/pip \
source /opt/conda/bin/activate base && ./install-vllm-cpu.sh

# TODO: workaround, remove this when fixed in vllm-cpu upstream
RUN --mount=type=cache,target=/root/.cache/pip \
pip install xformers
20 changes: 20 additions & 0 deletions dev/scripts/install-vllm-cpu.sh
@@ -0,0 +1,20 @@
#!/usr/bin/env bash

# Check tools
[[ -n $(which g++) ]] || { echo "GNU C++ Compiler (g++) is not found!"; exit 1; }
[[ -n $(which pip) ]] || { echo "pip command is not found!"; exit 1; }

# g++ version should be >=12.3
version_greater_equal()
{
printf '%s\n%s\n' "$2" "$1" | sort --check=quiet --version-sort
}
gcc_version=$(g++ -dumpversion)
echo
echo Current GNU C++ Compiler version: $gcc_version
echo
version_greater_equal "${gcc_version}" 12.3.0 || { echo "GNU C++ Compiler 12.3.0 or above is required!"; exit 1; }

# Install from source
MAX_JOBS=8 pip install -v git+https://github.com/bigPYJ1151/vllm@PR_Branch \
-f https://download.pytorch.org/whl/torch_stable.html
9 changes: 6 additions & 3 deletions dev/scripts/start-ray-cluster.sh
@@ -3,9 +3,12 @@
set -eo pipefail

# Setup oneapi envs before starting Ray
source /opt/intel/oneapi/setvars.sh

export CCL_ZE_IPC_EXCHANGE=sockets
if [[ -e "/opt/intel/oneapi/setvars.sh" ]]; then
source /opt/intel/oneapi/setvars.sh
export CCL_ZE_IPC_EXCHANGE=sockets
else
echo "/opt/intel/oneapi/setvars.sh doesn't exist, not loading."
fi

# Setup Ray cluster
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 ray start --head --node-ip-address 127.0.0.1 --ray-debugger-external
43 changes: 43 additions & 0 deletions docs/vllm.md
@@ -0,0 +1,43 @@
# Setting up vLLM For Intel CPU

__NOTICE: The support for vLLM is experimental and subject to change.__

## Install vLLM for Intel CPU

vLLM for CPU currently targets 4th Gen Intel® Xeon® Scalable processors (formerly codenamed Sapphire Rapids) for best performance; it can also run on other Xeon processors with FP32 precision.

Please run the following script to install vLLM for CPU into your current environment. A GNU C++ compiler (g++) version 12.3 or newer is currently required to build and install it.

```bash
$ dev/scripts/install-vllm-cpu.sh
```

## Setup

Please follow the [Deploying and Serving LLMs on Intel CPU/GPU/Gaudi](serve.md) document to set up the rest of the environment.

## Run

#### Serving

To serve a model with vLLM, run the following:

```bash
$ python serve.py --config_file inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml --simple --keep_serve_terminal
```

In the above example, the `enabled` property under `vllm` is set to `true` in the config file to enable vLLM.

#### Querying

To start a non-streaming query, run the following:

```bash
$ python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/llama-2-7b-chat-hf
```

To start a streaming query, run the following:

```bash
$ python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/llama-2-7b-chat-hf --streaming_response
```
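
To sanity-check the vLLM CPU build installed by `dev/scripts/install-vllm-cpu.sh` outside of Ray Serve, a minimal offline-generation run can be used. The sketch below assumes the CPU build exposes the standard vLLM Python API; the model id, dtype, and sampling settings are illustrative and not part of this commit.

```python
# Minimal vLLM smoke test (illustrative sketch, not code from this commit).
from vllm import LLM, SamplingParams

# A small model keeps the check fast; dtype="float32" matches the FP32 guidance above.
llm = LLM(model="facebook/opt-125m", dtype="float32")

sampling = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["What is Ray Serve?"], sampling)
for output in outputs:
    print(output.outputs[0].text)
```
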
4 changes: 2 additions & 2 deletions inference/deepspeed_predictor.py
@@ -18,7 +18,7 @@
    InferenceConfig,
    DEVICE_CPU,
    DEVICE_XPU,
    IPEX_PRECISION_BF16,
    PRECISION_BF16,
)


@@ -139,7 +139,7 @@ def init_model(self, local_rank: int):
            pipe.model = ipex.optimize_transformers(
                pipe.model.eval(),
                dtype=torch.bfloat16
                if self.infer_conf.ipex.precision == IPEX_PRECISION_BF16
                if self.infer_conf.ipex.precision == PRECISION_BF16
                else torch.float32,
                inplace=True,
            )
18 changes: 15 additions & 3 deletions inference/inference_config.py
@@ -3,8 +3,8 @@
from pydantic_yaml import parse_yaml_raw_as
from typing import List, Dict, Union

IPEX_PRECISION_BF16 = "bf16"
IPEX_PRECISION_FP32 = "fp32"
PRECISION_BF16 = "bf16"
PRECISION_FP32 = "fp32"

DEVICE_CPU = "cpu"
DEVICE_HPU = "hpu"
@@ -32,7 +32,18 @@ class Ipex(BaseModel):
@validator("precision")
def _check_precision(cls, v: str):
if v:
assert v in [IPEX_PRECISION_BF16, IPEX_PRECISION_FP32]
assert v in [PRECISION_BF16, PRECISION_FP32]
return v


class Vllm(BaseModel):
enabled: bool = False
precision: str = "bf16"

@validator("precision")
def _check_precision(cls, v: str):
if v:
assert v in [PRECISION_BF16, PRECISION_FP32]
return v


@@ -89,6 +100,7 @@ class InferenceConfig(BaseModel):
    gpus_per_worker: int = 0
    hpus_per_worker: int = 0
    deepspeed: bool = False
    vllm: Vllm = Vllm()
    workers_per_group: int = 2
    device: str = DEVICE_CPU
    ipex: Ipex = Ipex()
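
For reference, the sketch below shows how the new `vllm` field is picked up when a config file is parsed. The YAML fragment mirrors the config added in this commit, `parse_yaml_raw_as` is the loader already imported by this module, and the import path matches the one used in `inference/predictor.py`; the assumption that the remaining `InferenceConfig` fields keep their defaults is illustrative.

```python
# Sketch: parse a config fragment and inspect the new vllm settings.
from pydantic_yaml import parse_yaml_raw_as

from inference.inference_config import InferenceConfig  # import path as used in inference/predictor.py

yaml_text = """
name: llama-2-7b-chat-hf-vllm
device: "cpu"
vllm:
  enabled: true
  precision: fp32
"""

infer_conf = parse_yaml_raw_as(InferenceConfig, yaml_text)
print(infer_conf.vllm.enabled)    # True
print(infer_conf.vllm.precision)  # "fp32"; any other value fails the precision validator
```
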
27 changes: 27 additions & 0 deletions inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
@@ -0,0 +1,27 @@
port: 8000
name: llama-2-7b-chat-hf
route_prefix: /llama-2-7b-chat-hf
cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
vllm:
  enabled: true
  precision: bf16
workers_per_group: 2
device: "cpu"
ipex:
  enabled: false
  precision: bf16
model_description:
  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
  chat_processor: ChatModelLLama
  prompt:
    intro: ''
    human_id: '[INST] {msg} [/INST]
'
    bot_id: ''
    stop_words: []
  config:
    use_auth_token: ''
14 changes: 12 additions & 2 deletions inference/predictor.py
@@ -3,6 +3,7 @@
from transformers import AutoTokenizer, StoppingCriteriaList
from inference.inference_config import InferenceConfig
from utils import StoppingCriteriaSub
from typing import List, AsyncGenerator, Union


class Predictor:
@@ -72,11 +73,20 @@ def configure_tokenizer(self, model_name):
            tokenizer.pad_token = tokenizer.eos_token
            model.generation_config.pad_token_id = model.generation_config.eos_token_id

    def generate(self, prompt, **config):
    def generate(self, prompts: Union[str, List[str]], **config) -> Union[str, List[str]]:
        pass

    def streaming_generate(self, prompt, streamer, **config):
    async def generate_async(
        self, prompts: Union[str, List[str]], **config
    ) -> Union[str, List[str]]:
        pass

    # output is streamed into streamer
    def streaming_generate(self, prompt: str, streamer, **config) -> None:
        pass

    def get_streamer(self):
        pass

    async def stream_results(self, results_generator) -> AsyncGenerator[str, None]:
        pass
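
The new `vllm_predictor` module added by this commit is not rendered above. Purely to illustrate how the extended `Predictor` interface (str or List[str] in, matching output out) can map onto vLLM's offline API, a minimal sketch follows; the class name, constructor, and the assumption that `**config` carries `SamplingParams` keyword arguments are illustrative and not the commit's actual implementation.

```python
# Illustrative sketch only; not the vllm_predictor module shipped in this commit.
from typing import List, Union

from vllm import LLM, SamplingParams

from inference.inference_config import InferenceConfig, PRECISION_BF16


class VllmPredictorSketch:
    """Maps generate(str or List[str]) onto vLLM's offline generation API."""

    def __init__(self, infer_conf: InferenceConfig):
        dtype = "bfloat16" if infer_conf.vllm.precision == PRECISION_BF16 else "float32"
        # vLLM loads the model and tokenizer itself from the HF id or local path.
        self.engine = LLM(model=infer_conf.model_description.model_id_or_path, dtype=dtype)

    def generate(self, prompts: Union[str, List[str]], **config) -> Union[str, List[str]]:
        single = isinstance(prompts, str)
        batch = [prompts] if single else list(prompts)
        # config is assumed to hold SamplingParams kwargs such as max_tokens or temperature.
        outputs = self.engine.generate(batch, SamplingParams(**config))
        texts = [out.outputs[0].text for out in outputs]
        return texts[0] if single else texts
```
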
