From e6494c055525f2569e776dfd8870865b555d60f8 Mon Sep 17 00:00:00 2001 From: Xiaochang Wu Date: Thu, 18 Jan 2024 15:43:11 +0800 Subject: [PATCH] Add vllm Predictor (#20) * add vllm_predictor * add tests skeleton * add tests skeleton * add pytest.ini * wip * complete, debug wip * nit * nit * nit * complete generate supporting str and List[str] * add model * add streaming * remove tests * Add install-vllm-cpu script * nit Signed-off-by: Wu, Xiaochang * nit Signed-off-by: Wu, Xiaochang * nit * fix package inference * update install script and add doc * nit * nit * nit * add dtype support * nit * nit * nit * add ci * nit * nit * add libpthread-stubs0-dev * fix install-vllm-cpu * fix * revert inference.inference_config * debug ci * debug ci * debug ci * debug ci * debug ci * debug ci * debug ci * debug ci * update --------- Signed-off-by: Wu, Xiaochang --- .../config/llama-2-7b-chat-hf-vllm-fp32.yaml | 27 ++++++++ .github/workflows/workflow_inference.yml | 27 +++++--- .github/workflows/workflow_orders_on_pr.yml | 2 +- dev/docker/Dockerfile.vllm | 42 +++++++++++ dev/scripts/install-vllm-cpu.sh | 20 ++++++ dev/scripts/start-ray-cluster.sh | 9 ++- docs/vllm.md | 43 ++++++++++++ inference/deepspeed_predictor.py | 4 +- inference/inference_config.py | 18 ++++- .../models/vllm/llama-2-7b-chat-hf-vllm.yaml | 27 ++++++++ inference/predictor.py | 14 +++- inference/predictor_deployment.py | 28 ++++++-- inference/transformer_predictor.py | 5 +- inference/vllm_predictor.py | 69 +++++++++++++++++++ ui/start_ui.py | 2 +- 15 files changed, 309 insertions(+), 28 deletions(-) create mode 100644 .github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml create mode 100644 dev/docker/Dockerfile.vllm create mode 100755 dev/scripts/install-vllm-cpu.sh create mode 100644 docs/vllm.md create mode 100644 inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml create mode 100644 inference/vllm_predictor.py diff --git a/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml b/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml new file mode 100644 index 000000000..f51fda71c --- /dev/null +++ b/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml @@ -0,0 +1,27 @@ +port: 8000 +name: llama-2-7b-chat-hf-vllm +route_prefix: /llama-2-7b-chat-hf-vllm +cpus_per_worker: 24 +gpus_per_worker: 0 +deepspeed: false +vllm: + enabled: true + precision: fp32 +workers_per_group: 2 +device: "cpu" +ipex: + enabled: false + precision: bf16 +model_description: + model_id_or_path: meta-llama/Llama-2-7b-chat-hf + tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf + chat_processor: ChatModelLLama + prompt: + intro: '' + human_id: '[INST] {msg} [/INST] + + ' + bot_id: '' + stop_words: [] + config: + use_auth_token: '' diff --git a/.github/workflows/workflow_inference.yml b/.github/workflows/workflow_inference.yml index c7cb2fa72..f79194acf 100644 --- a/.github/workflows/workflow_inference.yml +++ b/.github/workflows/workflow_inference.yml @@ -34,7 +34,7 @@ jobs: name: inference test strategy: matrix: - model: [ gpt-j-6b, gpt2, bloom-560m, opt-125m, mpt-7b, mistral-7b-v0.1, mpt-7b-bigdl, neural-chat-7b-v3-1, CodeLlama-7b-hf, falcon-7b ] + model: [ gpt-j-6b, gpt2, bloom-560m, opt-125m, mpt-7b, mistral-7b-v0.1, mpt-7b-bigdl, neural-chat-7b-v3-1, CodeLlama-7b-hf, falcon-7b, llama-2-7b-chat-hf-vllm ] isPR: - ${{inputs.ci_type == 'pr'}} @@ -45,6 +45,7 @@ jobs: - { model: "gpt-j-6b"} - { model: "mistral-7b-v0.1"} - { model: "mpt-7b-bigdl"} + - { model: "llama-2-7b-chat-hf-vllm"} - dtuner_model: nathan0/mpt-7b-deltatuner-model model: 
mpt-7b @@ -64,13 +65,15 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 - + - name: Determine Target id: "target" run: | target="inference" if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then target="${target}_bigdl_cpu" + elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then + target="${target}_vllm" fi echo "target is ${target}" echo "target=$target" >> $GITHUB_OUTPUT @@ -79,6 +82,8 @@ jobs: run: | if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then DF_SUFFIX=".bigdl-cpu" + elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then + DF_SUFFIX=".vllm" else DF_SUFFIX=".cpu_and_deepspeed" fi @@ -106,12 +111,16 @@ jobs: TARGET=${{steps.target.outputs.target}} if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then docker exec "${TARGET}" bash -c "python inference/serve.py --config_file inference/models/bigdl/mpt-7b-bigdl.yaml --simple" + elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then + docker exec "${TARGET}" bash -c "python inference/serve.py --config_file .github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml --simple" else docker exec "${TARGET}" bash -c "python inference/serve.py --simple --models ${{ matrix.model }}" fi + echo Non-streaming query: docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }}" + echo Streaming query: docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }} --streaming_response" - + - name: Run Inference Test with Deltatuner if: ${{ matrix.dtuner_model }} run: | @@ -125,7 +134,7 @@ jobs: TARGET=${{steps.target.outputs.target}} if [[ ${{ matrix.model }} =~ ^(gpt2|falcon-7b|mpt-7b.*)$ ]]; then echo ${{ matrix.model }} is not supported! - else + elif [[ ! ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then docker exec "${TARGET}" bash -c "python .github/workflows/config/update_inference_config.py --config_file inference/models/\"${{ matrix.model }}\".yaml --output_file \"${{ matrix.model }}\".yaml.deepspeed --deepspeed" docker exec "${TARGET}" bash -c "python inference/serve.py --config_file \"${{ matrix.model }}\".yaml.deepspeed --simple" docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }}" @@ -143,16 +152,16 @@ jobs: docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }}" docker exec "${TARGET}" bash -c "python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/${{ matrix.model }} --streaming_response" fi - + - name: Run Inference Test with REST API run: | TARGET=${{steps.target.outputs.target}} if [[ ${{ matrix.model }} == "mpt-7b-bigdl" ]]; then docker exec "${TARGET}" bash -c "python inference/serve.py --config_file inference/models/bigdl/mpt-7b-bigdl.yaml" - else + elif [[ ! ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then docker exec "${TARGET}" bash -c "python inference/serve.py --models ${{ matrix.model }}" + docker exec "${TARGET}" bash -c "python examples/inference/api_server_openai/query_http_requests.py --model_name ${{ matrix.model }}" fi - docker exec "${TARGET}" bash -c "python examples/inference/api_server_openai/query_http_requests.py --model_name ${{ matrix.model }}" - name: Stop Ray run: | @@ -161,7 +170,7 @@ jobs: if [[ ! 
-z "$cid" ]]; then docker exec "${TARGET}" bash -c "ray stop" fi - + - name: Stop Container if: success() || failure() run: | @@ -173,4 +182,4 @@ jobs: run: echo "to be continued" - + diff --git a/.github/workflows/workflow_orders_on_pr.yml b/.github/workflows/workflow_orders_on_pr.yml index 30b401047..0fdb9bb01 100644 --- a/.github/workflows/workflow_orders_on_pr.yml +++ b/.github/workflows/workflow_orders_on_pr.yml @@ -20,7 +20,7 @@ jobs: call-lint: uses: ./.github/workflows/workflow_lint.yml - + call-tests: needs: call-lint uses: ./.github/workflows/workflow_tests.yml diff --git a/dev/docker/Dockerfile.vllm b/dev/docker/Dockerfile.vllm new file mode 100644 index 000000000..8e85f9e5c --- /dev/null +++ b/dev/docker/Dockerfile.vllm @@ -0,0 +1,42 @@ +# syntax=docker/dockerfile:1 +FROM ubuntu:22.04 + +ENV LANG C.UTF-8 + +WORKDIR /root/llm-on-ray + +RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \ + && apt-get install -y build-essential cmake wget curl git vim htop ssh net-tools \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +ENV CONDA_DIR /opt/conda +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ + /bin/bash ~/miniconda.sh -b -p /opt/conda +ENV PATH $CONDA_DIR/bin:$PATH + +# setup env +SHELL ["/bin/bash", "--login", "-c"] + +RUN --mount=type=cache,target=/opt/conda/pkgs conda init bash && \ + unset -f conda && \ + export PATH=$CONDA_DIR/bin/:${PATH} && \ + conda config --add channels intel && \ + conda install -y -c conda-forge python==3.9 gxx=12.3 gxx_linux-64=12.3 + +COPY ./pyproject.toml . +COPY ./dev/scripts/install-vllm-cpu.sh . + +RUN mkdir ./finetune && mkdir ./inference + +RUN --mount=type=cache,target=/root/.cache/pip pip install -e .[cpu] -f https://developer.intel.com/ipex-whl-stable-cpu \ + -f https://download.pytorch.org/whl/torch_stable.html + +# Install vllm-cpu +# Activate base first for loading g++ envs ($CONDA_PREFIX/etc/conda/activate.d/*) +RUN --mount=type=cache,target=/root/.cache/pip \ + source /opt/conda/bin/activate base && ./install-vllm-cpu.sh + +# TODO: workaround, remove this when fixed in vllm-cpu upstream +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install xformers diff --git a/dev/scripts/install-vllm-cpu.sh b/dev/scripts/install-vllm-cpu.sh new file mode 100755 index 000000000..64b3690a4 --- /dev/null +++ b/dev/scripts/install-vllm-cpu.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# Check tools +[[ -n $(which g++) ]] || { echo "GNU C++ Compiler (g++) is not found!"; exit 1; } +[[ -n $(which pip) ]] || { echo "pip command is not found!"; exit 1; } + +# g++ version should be >=12.3 +version_greater_equal() +{ + printf '%s\n%s\n' "$2" "$1" | sort --check=quiet --version-sort +} +gcc_version=$(g++ -dumpversion) +echo +echo Current GNU C++ Compiler version: $gcc_version +echo +version_greater_equal "${gcc_version}" 12.3.0 || { echo "GNU C++ Compiler 12.3.0 or above is required!"; exit 1; } + +# Install from source +MAX_JOBS=8 pip install -v git+https://github.com/bigPYJ1151/vllm@PR_Branch \ + -f https://download.pytorch.org/whl/torch_stable.html diff --git a/dev/scripts/start-ray-cluster.sh b/dev/scripts/start-ray-cluster.sh index 60c8b083d..59cd6f76a 100755 --- a/dev/scripts/start-ray-cluster.sh +++ b/dev/scripts/start-ray-cluster.sh @@ -3,9 +3,12 @@ set -eo pipefail # Setup oneapi envs before starting Ray -source /opt/intel/oneapi/setvars.sh - -export CCL_ZE_IPC_EXCHANGE=sockets +if [[ -e "/opt/intel/oneapi/setvars.sh" ]]; then + source /opt/intel/oneapi/setvars.sh 
+ export CCL_ZE_IPC_EXCHANGE=sockets +else + echo "/opt/intel/oneapi/setvars.sh doesn't exist, not loading." +fi # Setup Ray cluster RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 ray start --head --node-ip-address 127.0.0.1 --ray-debugger-external diff --git a/docs/vllm.md b/docs/vllm.md new file mode 100644 index 000000000..58393a9ae --- /dev/null +++ b/docs/vllm.md @@ -0,0 +1,43 @@ +# Setting up vLLM For Intel CPU + +__NOTICE: The support for vLLM is experimental and subject to change.__ + +## Install vLLM for Intel CPU + +vLLM for CPU currently supports Intel® 4th Gen Xeon® Scalable Performance processor (formerly codenamed Sapphire Rapids) for best performance and is runnable with FP32 precision for other Xeon processors. + +Please run the following script to install vLLM for CPU into your current environment. Currently a GNU C++ compiler with >=12.3 version is required to build and install. + +```bash +$ dev/scripts/install-vllm-cpu.sh +``` + +## Setup + +Please follow [Deploying and Serving LLMs on Intel CPU/GPU/Gaudi](serve.md) document to setup other environments. + +## Run + +#### Serving + +To serve model with vLLM, run the following: + +```bash +$ python serve.py --config_file inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml --simple --keep_serve_terminal +``` + +In the above example, `vllm` property is set to `true` in the config file for enabling vLLM. + +#### Querying + +To start a non-streaming query, run the following: + +```bash +$ python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/llama-2-7b-chat-hf +``` + +To start a streaming query, run the following: + +```bash +$ python examples/inference/api_server_simple/query_single.py --model_endpoint http://127.0.0.1:8000/llama-2-7b-chat-hf --streaming_response +``` \ No newline at end of file diff --git a/inference/deepspeed_predictor.py b/inference/deepspeed_predictor.py index b9ec6cda9..506137ca4 100644 --- a/inference/deepspeed_predictor.py +++ b/inference/deepspeed_predictor.py @@ -18,7 +18,7 @@ InferenceConfig, DEVICE_CPU, DEVICE_XPU, - IPEX_PRECISION_BF16, + PRECISION_BF16, ) @@ -139,7 +139,7 @@ def init_model(self, local_rank: int): pipe.model = ipex.optimize_transformers( pipe.model.eval(), dtype=torch.bfloat16 - if self.infer_conf.ipex.precision == IPEX_PRECISION_BF16 + if self.infer_conf.ipex.precision == PRECISION_BF16 else torch.float32, inplace=True, ) diff --git a/inference/inference_config.py b/inference/inference_config.py index 0e9fd50f7..08dc5cf07 100644 --- a/inference/inference_config.py +++ b/inference/inference_config.py @@ -3,8 +3,8 @@ from pydantic_yaml import parse_yaml_raw_as from typing import List, Dict, Union -IPEX_PRECISION_BF16 = "bf16" -IPEX_PRECISION_FP32 = "fp32" +PRECISION_BF16 = "bf16" +PRECISION_FP32 = "fp32" DEVICE_CPU = "cpu" DEVICE_HPU = "hpu" @@ -32,7 +32,18 @@ class Ipex(BaseModel): @validator("precision") def _check_precision(cls, v: str): if v: - assert v in [IPEX_PRECISION_BF16, IPEX_PRECISION_FP32] + assert v in [PRECISION_BF16, PRECISION_FP32] + return v + + +class Vllm(BaseModel): + enabled: bool = False + precision: str = "bf16" + + @validator("precision") + def _check_precision(cls, v: str): + if v: + assert v in [PRECISION_BF16, PRECISION_FP32] return v @@ -89,6 +100,7 @@ class InferenceConfig(BaseModel): gpus_per_worker: int = 0 hpus_per_worker: int = 0 deepspeed: bool = False + vllm: Vllm = Vllm() workers_per_group: int = 2 device: str = DEVICE_CPU ipex: Ipex = Ipex() diff --git a/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml 
b/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml new file mode 100644 index 000000000..bc0ca2986 --- /dev/null +++ b/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml @@ -0,0 +1,27 @@ +port: 8000 +name: llama-2-7b-chat-hf +route_prefix: /llama-2-7b-chat-hf +cpus_per_worker: 24 +gpus_per_worker: 0 +deepspeed: false +vllm: + enabled: true + precision: bf16 +workers_per_group: 2 +device: "cpu" +ipex: + enabled: false + precision: bf16 +model_description: + model_id_or_path: meta-llama/Llama-2-7b-chat-hf + tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf + chat_processor: ChatModelLLama + prompt: + intro: '' + human_id: '[INST] {msg} [/INST] + + ' + bot_id: '' + stop_words: [] + config: + use_auth_token: '' diff --git a/inference/predictor.py b/inference/predictor.py index 1965f3d59..4f7c9d3af 100644 --- a/inference/predictor.py +++ b/inference/predictor.py @@ -3,6 +3,7 @@ from transformers import AutoTokenizer, StoppingCriteriaList from inference.inference_config import InferenceConfig from utils import StoppingCriteriaSub +from typing import List, AsyncGenerator, Union class Predictor: @@ -72,11 +73,20 @@ def configure_tokenizer(self, model_name): tokenizer.pad_token = tokenizer.eos_token model.generation_config.pad_token_id = model.generation_config.eos_token_id - def generate(self, prompt, **config): + def generate(self, prompts: Union[str, List[str]], **config) -> Union[str, List[str]]: pass - def streaming_generate(self, prompt, streamer, **config): + async def generate_async( + self, prompts: Union[str, List[str]], **config + ) -> Union[str, List[str]]: + pass + + # output is streamed into streamer + def streaming_generate(self, prompt: str, streamer, **config) -> None: pass def get_streamer(self): pass + + async def stream_results(self, results_generator) -> AsyncGenerator[str, None]: + pass diff --git a/inference/predictor_deployment.py b/inference/predictor_deployment.py index fff8375d1..2828931d5 100644 --- a/inference/predictor_deployment.py +++ b/inference/predictor_deployment.py @@ -53,11 +53,17 @@ def __init__(self, infer_conf: InferenceConfig): self.process_tool = chat_processor(**prompt.dict()) self.use_deepspeed = infer_conf.deepspeed + self.use_vllm = infer_conf.vllm.enabled + if self.use_deepspeed: from deepspeed_predictor import DeepSpeedPredictor self.predictor = DeepSpeedPredictor(infer_conf) self.streamer = self.predictor.get_streamer() + elif self.use_vllm: + from vllm_predictor import VllmPredictor + + self.predictor = VllmPredictor(infer_conf) else: from transformer_predictor import TransformerPredictor @@ -94,13 +100,29 @@ async def __call__(self, http_request: Request) -> Union[StreamingResponse, str] prompts.extend(text) else: prompts.append(text) + if not streaming_response: - return self.predictor.generate(prompts, **config) + if self.use_vllm: + return await self.predictor.generate_async(prompts, **config) + else: + return self.predictor.generate(prompts, **config) + if self.use_deepspeed: self.predictor.streaming_generate(prompts, self.streamer, **config) return StreamingResponse( self.consume_streamer(), status_code=200, media_type="text/plain" ) + elif self.use_vllm: + # TODO: streaming only support single prompt + # It's a wordaround for current situation, need another PR to address this + if isinstance(prompts, list): + prompt = prompts[0] + results_generator = await self.predictor.streaming_generate_async(prompt, **config) + return StreamingResponse( + self.predictor.stream_results(results_generator), + status_code=200, + 
media_type="text/plain", + ) else: streamer = self.predictor.get_streamer() self.loop.run_in_executor( @@ -108,9 +130,7 @@ async def __call__(self, http_request: Request) -> Union[StreamingResponse, str] functools.partial(self.predictor.streaming_generate, prompts, streamer, **config), ) return StreamingResponse( - self.consume_streamer_async(streamer), - status_code=200, - media_type="text/plain", + self.consume_streamer_async(streamer), status_code=200, media_type="text/plain" ) async def stream_response(self, prompt, config): diff --git a/inference/transformer_predictor.py b/inference/transformer_predictor.py index 2784016b9..70e90ebe6 100644 --- a/inference/transformer_predictor.py +++ b/inference/transformer_predictor.py @@ -1,7 +1,7 @@ import torch from transformers import AutoModelForCausalLM, AutoConfig from transformers import TextIteratorStreamer -from inference.inference_config import InferenceConfig, IPEX_PRECISION_BF16 +from inference.inference_config import InferenceConfig, PRECISION_BF16 from predictor import Predictor from utils import get_torch_dtype @@ -9,7 +9,6 @@ class TransformerPredictor(Predictor): def __init__(self, infer_conf: InferenceConfig): super().__init__(infer_conf) - model_desc = infer_conf.model_description model_config = model_desc.config hf_config = AutoConfig.from_pretrained( @@ -86,7 +85,7 @@ def __init__(self, infer_conf: InferenceConfig): model = ipex.optimize_transformers( model.eval(), dtype=torch.bfloat16 - if infer_conf.ipex.precision == IPEX_PRECISION_BF16 + if infer_conf.ipex.precision == PRECISION_BF16 else torch.float32, inplace=True, ) diff --git a/inference/vllm_predictor.py b/inference/vllm_predictor.py new file mode 100644 index 000000000..6123b3906 --- /dev/null +++ b/inference/vllm_predictor.py @@ -0,0 +1,69 @@ +import asyncio +from typing import AsyncGenerator, List, Union +from predictor import Predictor +from inference.inference_config import InferenceConfig, PRECISION_BF16 +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + + +class VllmPredictor(Predictor): + def __init__(self, infer_conf: InferenceConfig): + super().__init__(infer_conf) + + model_desc = infer_conf.model_description + model_config = model_desc.config + dtype = "bfloat16" if infer_conf.vllm.precision == PRECISION_BF16 else "float32" + + args = AsyncEngineArgs( + model=model_desc.model_id_or_path, + trust_remote_code=model_config.trust_remote_code, + device=infer_conf.device, + dtype=dtype, + ) + + self.engine = AsyncLLMEngine.from_engine_args(args) + + async def _get_generator_output(self, results_generator): + async for request_output in results_generator: + if request_output.finished: + return request_output.outputs[0].text + return None + + async def generate_async( + self, prompts: Union[str, List[str]], **config + ) -> Union[str, List[str]]: + sampling_params = SamplingParams(**config) + if isinstance(prompts, str): + request_id = random_uuid() + results_generator = self.engine.generate(prompts, sampling_params, request_id) + async for request_output in results_generator: + if request_output.finished: + return request_output.outputs[0].text + else: + results_generators = [ + self.engine.generate(prompt, sampling_params, random_uuid()) for prompt in prompts + ] + results = [ + self._get_generator_output(results_generator) + for results_generator in results_generators + ] + return await asyncio.gather(*results) + + return 
"" + + async def streaming_generate_async(self, prompt, **config): + sampling_params = SamplingParams(**config) + request_id = random_uuid() + results_generator = self.engine.generate(prompt, sampling_params, request_id) + return results_generator + + async def stream_results(self, results_generator) -> AsyncGenerator[str, None]: + num_returned = 0 + async for request_output in results_generator: + text_outputs = [output.text for output in request_output.outputs] + assert len(text_outputs) == 1 + text_output = text_outputs[0][num_returned:] + yield text_output + num_returned += len(text_output) diff --git a/ui/start_ui.py b/ui/start_ui.py index 33df94177..c9ab0e23e 100644 --- a/ui/start_ui.py +++ b/ui/start_ui.py @@ -752,7 +752,7 @@ def _init_ui(self): head_content = """
-        <!-- original banner markup with the "Manage LLM Lifecycle" heading (HTML not recoverable from extraction) -->
+        <!-- updated banner markup with the "Manage LLM Lifecycle" heading and the "Fine-Tune LLMs using workflow on Ray, Deploy and Inference" tagline (HTML not recoverable from extraction) -->
"""