diff --git a/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml b/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
index d3d96a0e1..8445a8fd5 100644
--- a/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
+++ b/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
@@ -14,7 +14,7 @@ ipex:
   enabled: false
   precision: bf16
 model_description:
-  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-7b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-chat-hf
   config:
     use_auth_token: ''
diff --git a/.github/workflows/workflow_finetune.yml b/.github/workflows/workflow_finetune.yml
index 8b0347de1..e993011cb 100644
--- a/.github/workflows/workflow_finetune.yml
+++ b/.github/workflows/workflow_finetune.yml
@@ -11,10 +11,10 @@ on:
         default: '10.1.2.13:5000/llmray-build'
       http_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       https_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       runner_config_path:
         type: string
         default: '/home/ci/llm-ray-actions-runner'
@@ -34,7 +34,7 @@ jobs:
     name: finetune
     strategy:
       matrix:
-        model: [ EleutherAI/gpt-j-6b, meta-llama/Llama-2-7b-chat-hf, gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b, meta-llama/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, google/gemma-2b]
+        model: [ EleutherAI/gpt-j-6b, NousResearch/Llama-2-7b-chat-hf, gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b, NousResearch/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, google/gemma-2b]
         isPR:
           - ${{inputs.ci_type == 'pr'}}
@@ -42,7 +42,7 @@ jobs:
           - { isPR: true }
         include:
           - { model: "EleutherAI/gpt-j-6b"}
-          - { model: "meta-llama/Llama-2-7b-chat-hf"}
+          - { model: "NousResearch/Llama-2-7b-chat-hf"}
           - { model: "mistralai/Mistral-7B-v0.1"}
           - { model: "google/gemma-2b"}
@@ -65,9 +65,6 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Load environment variables
-        run: cat /root/actions-runner-config/.env >> $GITHUB_ENV
-
       - name: Build Docker Image
         run: |
           DF_SUFFIX=".cpu_and_deepspeed"
@@ -83,7 +80,7 @@ jobs:
           model_cache_path=${{ inputs.model_cache_path }}
           USE_PROXY="1"
           source dev/scripts/ci-functions.sh
-          start_docker ${TARGET} ${code_checkout_path} ${model_cache_path} ${USE_PROXY} ${{env.HF_ACCESS_TOKEN}}
+          start_docker ${TARGET} ${code_checkout_path} ${model_cache_path} ${USE_PROXY}

       - name: Run Finetune Test
         run: |
diff --git a/.github/workflows/workflow_finetune_gpu.yml b/.github/workflows/workflow_finetune_gpu.yml
index 2114b66db..37e612324 100644
--- a/.github/workflows/workflow_finetune_gpu.yml
+++ b/.github/workflows/workflow_finetune_gpu.yml
@@ -8,17 +8,17 @@ on:
         default: '10.1.2.13:5000/llmray-build'
       http_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       https_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'

 jobs:
   finetune-gpu:
     name: finetune-gpu
     strategy:
       matrix:
-        model: [ meta-llama/Llama-2-7b-chat-hf ]
+        model: [ NousResearch/Llama-2-7b-chat-hf ]
     runs-on: self-hosted

     defaults:
diff --git a/.github/workflows/workflow_inference.yml b/.github/workflows/workflow_inference.yml
index 61f458bcd..ca57affac 100644
--- a/.github/workflows/workflow_inference.yml
+++ b/.github/workflows/workflow_inference.yml
@@ -11,10 +11,10 @@ on:
         default: '10.1.2.13:5000/llmray-build'
       http_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       https_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       runner_config_path:
         type: string
         default: '/home/ci/llm-ray-actions-runner'
@@ -67,9 +67,6 @@ jobs:
      - name: Checkout
        uses: actions/checkout@v4

-      - name: Load environment variables
-        run: cat /root/actions-runner-config/.env >> $GITHUB_ENV
-
       - name: Determine Target
         id: "target"
         run: |
@@ -94,7 +91,7 @@ jobs:
           model_cache_path=${{ inputs.model_cache_path }}
           USE_PROXY="1"
           source dev/scripts/ci-functions.sh
-          start_docker ${TARGET} ${code_checkout_path} ${model_cache_path} ${USE_PROXY} ${{env.HF_ACCESS_TOKEN}}
+          start_docker ${TARGET} ${code_checkout_path} ${model_cache_path} ${USE_PROXY}

       - name: Start Ray Cluster
         run: |
diff --git a/.github/workflows/workflow_inference_gaudi2.yml b/.github/workflows/workflow_inference_gaudi2.yml
index 588e8dab0..dedeb4154 100644
--- a/.github/workflows/workflow_inference_gaudi2.yml
+++ b/.github/workflows/workflow_inference_gaudi2.yml
@@ -73,9 +73,6 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Load environment variables
-        run: cat /root/actions-runner-config/.env >> $GITHUB_ENV
-
       - name: Build Docker Image
         run: |
           DF_SUFFIX=".gaudi2"
@@ -98,7 +95,6 @@ jobs:
           cid=$(docker ps -a -q --filter "name=${TARGET}")
           if [[ ! -z "$cid" ]]; then docker rm $cid; fi
           docker run -tid --name="${TARGET}" --hostname="${TARGET}-container" --runtime=habana -v /home/yizhong/Model-References:/root/Model-References -v ${{ inputs.code_checkout_path }}:/root/llm-on-ray -v ${{ inputs.model_cache_path }}:/root/.cache/huggingface/hub/ -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --cap-add sys_ptrace --net=host --ipc=host ${TARGET}:habana
-
       - name: Start Ray Cluster
         run: |
           TARGET=${{steps.target.outputs.target}}
@@ -117,7 +113,6 @@ jobs:
           conf_path = "llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-vllm-hpu.yaml"
           with open(conf_path, encoding="utf-8") as reader:
               result = yaml.load(reader, Loader=yaml.FullLoader)
-              result['model_description']["config"]["use_auth_token"] = "${{ env.HF_ACCESS_TOKEN }}"
           with open(conf_path, 'w') as output:
               yaml.dump(result, output, sort_keys=False)
           EOF
@@ -128,7 +123,6 @@ jobs:
           elif [[ ${{ matrix.model }} == "llama-2-70b-chat-hf" ]]; then
             docker exec "${TARGET}" bash -c "llm_on_ray-serve --config_file llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml --keep_serve_terminal"
           elif [[ ${{ matrix.model }} == "llama-2-7b-chat-hf-vllm" ]]; then
-            docker exec "${TARGET}" bash -c "huggingface-cli login --token ${{ env.HF_ACCESS_TOKEN }}"
             docker exec "${TARGET}" bash -c "llm_on_ray-serve --config_file llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-vllm-hpu.yaml --keep_serve_terminal"
           fi
           echo Streaming query:
diff --git a/.github/workflows/workflow_test_benchmark.yml b/.github/workflows/workflow_test_benchmark.yml
index 2f78c997d..d737675c0 100644
--- a/.github/workflows/workflow_test_benchmark.yml
+++ b/.github/workflows/workflow_test_benchmark.yml
@@ -11,10 +11,10 @@ on:
         default: '10.1.2.13:5000/llmray-build'
       http_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       https_proxy:
         type: string
-        default: 'http://10.24.221.169:911'
+        default: 'http://10.24.221.169:912'
       runner_config_path:
         type: string
         default: '/home/ci/llm-ray-actions-runner'
@@ -92,24 +92,6 @@ jobs:
           TARGET=${{steps.target.outputs.target}}
           # Additional libraries required for pytest
           docker exec "${TARGET}" bash -c "pip install -r tests/requirements.txt"
-          CMD=$(cat << EOF
-          import yaml
-          conf_path = "llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml"
-          with open(conf_path, encoding="utf-8") as reader:
-              result = yaml.load(reader, Loader=yaml.FullLoader)
-              result['model_description']["config"]["use_auth_token"] = "${{ env.HF_ACCESS_TOKEN }}"
-          with open(conf_path, 'w') as output:
-              yaml.dump(result, output, sort_keys=False)
-          conf_path = "llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml"
-          with open(conf_path, encoding="utf-8") as reader:
-              result = yaml.load(reader, Loader=yaml.FullLoader)
-              result['model_description']["config"]["use_auth_token"] = "${{ env.HF_ACCESS_TOKEN }}"
-          with open(conf_path, 'w') as output:
-              yaml.dump(result, output, sort_keys=False)
-          EOF
-          )
-          docker exec "${TARGET}" python -c "$CMD"
-          docker exec "${TARGET}" bash -c "huggingface-cli login --token ${{ env.HF_ACCESS_TOKEN }}"
           docker exec "${TARGET}" bash -c "./tests/run-tests-benchmark.sh"
       - name: Stop Ray
         run: |
diff --git a/README.md b/README.md
index c0967ab34..4870bab48 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,14 @@ Deploy a model on Ray and expose an endpoint for serving. This command uses GPT2
 ```bash
 llm_on_ray-serve --config_file llm_on_ray/inference/models/gpt2.yaml
 ```
-
+You can also serve a model directly by specifying its model_id:
+```bash
+llm_on_ray-serve --models gpt2
+```
+List all supported model_ids along with their config file paths:
+```bash
+llm_on_ray-serve --list_model_ids
+```
 The default served method is to provide an OpenAI-compatible API server ([OpenAI API Reference](https://platform.openai.com/docs/api-reference/chat)), you can access and test it in many ways:
 ```bash
 # using curl
diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh
index 7cd0d07fb..3c7118dd7 100644
--- a/benchmarks/run_benchmark.sh
+++ b/benchmarks/run_benchmark.sh
@@ -229,4 +229,4 @@ then
 fi
 output_tokens_length=32
 get_best_latency $iter "${input_tokens_length[*]}" $output_tokens_length $benchmark_dir
-fi
+fi
\ No newline at end of file
diff --git a/dev/scripts/ci-functions.sh b/dev/scripts/ci-functions.sh
index 993e71dde..ce570f941 100644
--- a/dev/scripts/ci-functions.sh
+++ b/dev/scripts/ci-functions.sh
@@ -1,8 +1,8 @@
 #!/usr/bin/env bash
 set -eo pipefail

-HTTP_PROXY='http://10.24.221.169:911'
-HTTPS_PROXY='http://10.24.221.169:911'
+HTTP_PROXY='http://10.24.221.169:912'
+HTTPS_PROXY='http://10.24.221.169:912'
 MODEL_CACHE_PATH_LOACL='/root/.cache/huggingface/hub'
 CODE_CHECKOUT_PATH_LOCAL='/root/llm-on-ray'

@@ -59,7 +59,6 @@ start_docker() {
    local code_checkout_path=$2
    local model_cache_path=$3
    local USE_PROXY=$4
-   local HF_TOKEN=$5

    cid=$(docker ps -q --filter "name=${TARGET}")
    if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid; fi
@@ -86,12 +85,7 @@ start_docker() {
    fi

    echo "docker run -tid "${docker_args[@]}" "${TARGET}:latest""
-   docker run -tid "${docker_args[@]}" "${TARGET}:latest"
-   if [ -z "$HF_TOKEN" ]; then
-       echo "no hf token"
-   else
-       docker exec "${TARGET}" bash -c "huggingface-cli login --token ${HF_TOKEN}"
-   fi
+   docker run -tid "${docker_args[@]}" "${TARGET}:latest"
 }

 start_docker_gaudi() {
diff --git a/examples/inference/api_server_simple/query_dynamic_batch.py b/examples/inference/api_server_simple/query_dynamic_batch.py
index a9e1b8837..2982a9636 100644
--- a/examples/inference/api_server_simple/query_dynamic_batch.py
+++ b/examples/inference/api_server_simple/query_dynamic_batch.py
@@ -18,6 +18,10 @@
 import aiohttp
 import argparse
 from typing import Dict, Union
+from llm_on_ray.inference.api_simple_backend.simple_protocol import (
+    SimpleRequest,
+    SimpleModelResponse,
+)

 parser = argparse.ArgumentParser(
     description="Example script to query with multiple requests", add_help=True
@@ -63,9 +67,8 @@
     config["top_k"] = float(args.top_k)


-async def send_query(session, endpoint, prompt, config):
-    json_request = {"text": prompt, "config": config}
-    async with session.post(endpoint, json=json_request) as resp:
+async def send_query(session, endpoint, req):
+    async with session.post(endpoint, json=req.dict()) as resp:
         return await resp.text()


@@ -86,16 +89,15 @@
 configs = [config1] * 5 + [config2] * (len(prompts) - 5)

+reqs = [SimpleRequest(text=prompt, config=config) for prompt, config in zip(prompts, configs)]
+

-async def send_all_query(endpoint, prompts, configs):
+async def send_all_query(endpoint, reqs):
     async with aiohttp.ClientSession() as session:
-        tasks = [
-            send_query(session, endpoint, prompt, config)
-            for prompt, config in zip(prompts, configs)
-        ]
+        tasks = [send_query(session, endpoint, req) for req in reqs]
         responses = await asyncio.gather(*tasks)
         print("\n--------------\n".join(responses))
         print("\nTotal responses:", len(responses))


-asyncio.run(send_all_query(args.model_endpoint, prompts, configs))
+asyncio.run(send_all_query(args.model_endpoint, reqs))
diff --git a/examples/inference/api_server_simple/query_single.py b/examples/inference/api_server_simple/query_single.py
index 62bb4dc45..5aabdd8ea 100644
--- a/examples/inference/api_server_simple/query_single.py
+++ b/examples/inference/api_server_simple/query_single.py
@@ -17,6 +17,10 @@
 import requests
 import argparse
 from typing import Dict, Union
+from llm_on_ray.inference.api_simple_backend.simple_protocol import (
+    SimpleRequest,
+    SimpleModelResponse,
+)

 parser = argparse.ArgumentParser(
     description="Example script to query with single request", add_help=True
@@ -66,20 +70,22 @@
 if args.top_k:
     config["top_k"] = float(args.top_k)

-sample_input = {"text": prompt, "config": config, "stream": args.streaming_response}
+sample_input = SimpleRequest(text=prompt, config=config, stream=args.streaming_response)

 proxies = {"http": None, "https": None}
 outputs = requests.post(
     args.model_endpoint,
     proxies=proxies,  # type: ignore
-    json=sample_input,
+    json=sample_input.dict(),
     stream=args.streaming_response,
 )

 outputs.raise_for_status()
+
+simple_response = SimpleModelResponse.from_requests_response(outputs)
 if args.streaming_response:
-    for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
+    for output in simple_response.iter_content(chunk_size=1, decode_unicode=True):
         print(output, end="", flush=True)
     print()
 else:
-    print(outputs.text, flush=True)
+    print(simple_response.text, flush=True)
diff --git a/llm_on_ray/inference/api_simple_backend/simple_protocol.py b/llm_on_ray/inference/api_simple_backend/simple_protocol.py
new file mode 100644
index 000000000..b873763ad
--- /dev/null
+++ b/llm_on_ray/inference/api_simple_backend/simple_protocol.py
@@ -0,0 +1,91 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Optional, Union, Iterator, List
+import requests
+from pydantic import BaseModel, ValidationError, validator
+
+
+class SimpleRequest(BaseModel):
+    text: str
+    config: Dict[str, Union[int, float]] = {}
+    stream: Optional[bool] = False
+
+    @validator("text")
+    def text_must_not_be_empty(cls, v):
+        if not v.strip():
+            raise ValueError("Empty prompt is not supported.")
+        return v
+
+    @validator("config", pre=True)
+    def check_config_type(cls, value):
+        if not isinstance(value, dict):
+            raise ValueError("Config must be a dictionary")
+
+        allowed_keys = ["max_new_tokens", "temperature", "top_p", "top_k"]
+        allowed_set = set(allowed_keys)
+        config_keys = [key for key in value.keys()]
+        config_set = set(config_keys)
+
+        if not all(isinstance(key, str) for key in value.keys()):
+            raise ValueError("All keys in config must be strings")
+
+        if not all(isinstance(val, (int, float)) for val in value.values()):
+            raise ValueError("All values in config must be integers or floats")
+
+        if not config_set.issubset(allowed_set):
+            invalid_keys = config_set - allowed_set
+            raise ValueError(f'Invalid config keys: {", ".join(invalid_keys)}')
+
+        return value
+
+    @validator("stream", pre=True)
+    def check_stream_type(cls, value):
+        if not isinstance(value, bool) and value is not None:
+            raise ValueError("Stream must be a boolean or None")
+        return value
+
+
+class SimpleModelResponse(BaseModel):
+    headers: Dict[str, str]
+    text: str
+    content: bytes
+    status_code: int
+    url: str
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    response: Optional[requests.Response] = None
+
+    @staticmethod
+    def from_requests_response(response: requests.Response):
+        return SimpleModelResponse(
+            headers=dict(response.headers),
+            text=response.text,
+            content=response.content,
+            status_code=response.status_code,
+            url=response.url,
+            response=response,
+        )
+
+    def iter_content(
+        self, chunk_size: Optional[int] = 1, decode_unicode: bool = False
+    ) -> Iterator[Union[bytes, str]]:
+        if self.response is not None:
+            return self.response.iter_content(chunk_size=chunk_size, decode_unicode=decode_unicode)
+        else:
+            return iter([])
diff --git a/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml
index ab411ff0e..4ecf45cd0 100644
--- a/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml
@@ -8,7 +8,7 @@ deepspeed: true
 workers_per_group: 8
 device: hpu
 model_description:
-  model_id_or_path: meta-llama/Llama-2-70b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-70b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-70b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-70b-chat-hf
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml
index b7b19f02a..cb57f2768 100644
--- a/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml
@@ -6,7 +6,7 @@ cpus_per_worker: 8
 hpus_per_worker: 1
 device: hpu
 model_description:
-  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-7b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-chat-hf
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-vllm-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-vllm-hpu.yaml
index 869f41d7a..a9d10ccbd 100644
--- a/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-vllm-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-vllm-hpu.yaml
@@ -16,7 +16,7 @@ ipex:
   enabled: false
   precision: bf16
 model_description:
-  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-7b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-chat-hf
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/hpu/llama-3-70b-chat-hf-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-3-70b-chat-hf-hpu.yaml
index 32cf9bb4e..eb3bab468 100644
--- a/llm_on_ray/inference/models/hpu/llama-3-70b-chat-hf-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-3-70b-chat-hf-hpu.yaml
@@ -7,7 +7,7 @@ deepspeed: true
 workers_per_group: 8
 device: hpu
 model_description:
-  model_id_or_path: meta-llama/Meta-Llama-3-70b-Instruct
-  tokenizer_name_or_path: meta-llama/Meta-Llama-3-70b-Instruct
+  model_id_or_path: NousResearch/Meta-Llama-3-70B-Instruct
+  tokenizer_name_or_path: NousResearch/Meta-Llama-3-70B-Instruct
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/hpu/llama-3-8b-instruct-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-3-8b-instruct-hpu.yaml
index d57ffcc22..3789ab984 100644
--- a/llm_on_ray/inference/models/hpu/llama-3-8b-instruct-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-3-8b-instruct-hpu.yaml
@@ -6,7 +6,7 @@ cpus_per_worker: 8
 hpus_per_worker: 1
 device: hpu
 model_description:
-  model_id_or_path: meta-llama/Meta-Llama-3-8b-Instruct
-  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8b-Instruct
+  model_id_or_path: NousResearch/Meta-Llama-3-8B-Instruct
+  tokenizer_name_or_path: NousResearch/Meta-Llama-3-8B-Instruct
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml b/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml
index 1f648d857..81cb74d98 100644
--- a/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml
+++ b/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml
@@ -12,7 +12,7 @@ ipex:
   enabled: false
   precision: bf16
 model_description:
-  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-7b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-chat-hf
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-autoscaling.yaml b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-autoscaling.yaml
index 207466a63..ba32990a6 100644
--- a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-autoscaling.yaml
+++ b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-autoscaling.yaml
@@ -22,7 +22,7 @@ ipex:
   enabled: false
   precision: bf16
 model_description:
-  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-7b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-chat-hf
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
index 5db264c9e..29d562aa9 100644
--- a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
+++ b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
@@ -15,7 +15,7 @@ ipex:
   enabled: false
   precision: bf16
 model_description:
-  model_id_or_path: meta-llama/Llama-2-7b-chat-hf
-  tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
+  model_id_or_path: NousResearch/Llama-2-7b-chat-hf
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-chat-hf
   config:
     use_auth_token: ''
diff --git a/llm_on_ray/inference/predictor_deployment.py b/llm_on_ray/inference/predictor_deployment.py
index a1055915d..ed67f5119 100644
--- a/llm_on_ray/inference/predictor_deployment.py
+++ b/llm_on_ray/inference/predictor_deployment.py
@@ -34,6 +34,10 @@
     ErrorResponse,
     ModelResponse,
 )
+from llm_on_ray.inference.api_simple_backend.simple_protocol import (
+    SimpleRequest,
+    SimpleModelResponse,
+)
 from llm_on_ray.inference.predictor import GenerateInput
 from llm_on_ray.inference.utils import get_prompt_format, PromptFormat
 from llm_on_ray.inference.api_openai_backend.tools import OpenAIToolsPrompter, ChatPromptCapture
@@ -379,24 +383,18 @@ def preprocess_prompts(

     async def __call__(self, http_request: Request) -> Union[StreamingResponse, JSONResponse, str]:
         self.use_openai = False
-
         try:
-            json_request: Dict[str, Any] = await http_request.json()
+            request: Dict[str, Any] = await http_request.json()
         except ValueError:
             return JSONResponse(
                 status_code=400,
                 content="Invalid JSON format from http request.",
             )
-        streaming_response = json_request["stream"] if "stream" in json_request else False
-        input = json_request["text"] if "text" in json_request else ""
-        if input == "":
-            return JSONResponse(
-                status_code=400,
-                content="Empty prompt is not supported.",
-            )
-        config = json_request["config"] if "config" in json_request else {}
-        # return prompt or list of prompts preprocessed
+        streaming_response = request["stream"]
+        input = request["text"]
+        config = request["config"]
+
         prompts = self.preprocess_prompts(input)

         # Handle streaming response
diff --git a/llm_on_ray/inference/predictors/deepspeed_predictor.py b/llm_on_ray/inference/predictors/deepspeed_predictor.py
index e35fedbf0..2508286a7 100644
--- a/llm_on_ray/inference/predictors/deepspeed_predictor.py
+++ b/llm_on_ray/inference/predictors/deepspeed_predictor.py
@@ -53,11 +53,15 @@ def __init__(self, infer_conf: InferenceConfig, pad_token_id, stopping_criteria)
         model_desc = infer_conf.model_description
         model_config = model_desc.config
+        if infer_conf.model_description.config.use_auth_token:
+            auth_token = infer_conf.model_description.config.use_auth_token
+        else:
+            auth_token = None
         hf_config = AutoConfig.from_pretrained(
             model_desc.model_id_or_path,
             torchscript=True,
             trust_remote_code=model_config.trust_remote_code,
-            use_auth_token=infer_conf.model_description.config.use_auth_token,
+            use_auth_token=auth_token,
         )

         # decide correct torch type for loading HF model
@@ -75,7 +79,7 @@ def __init__(self, infer_conf: InferenceConfig, pad_token_id, stopping_criteria)
             self.model = PeftModel.from_pretrained(
                 self.model,
                 model_desc.peft_model_id_or_path,
-                use_auth_token=infer_conf.model_description.config.use_auth_token,
+                use_auth_token=auth_token,
             )
             self.model = self.model.merge_and_unload()
diff --git a/llm_on_ray/inference/predictors/hpu_predictor.py b/llm_on_ray/inference/predictors/hpu_predictor.py
index 4710e0bf9..5e19c8733 100644
--- a/llm_on_ray/inference/predictors/hpu_predictor.py
+++ b/llm_on_ray/inference/predictors/hpu_predictor.py
@@ -314,11 +314,15 @@ def load_model(self):
             model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)

             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
+            if model_desc.config.use_auth_token:
+                auth_token = model_desc.config.use_auth_token
+            else:
+                auth_token = None
             write_checkpoints_json(
                 model_desc.model_id_or_path,
                 self.local_rank,
                 checkpoints_json,
-                token=model_desc.config.use_auth_token,
+                token=auth_token,
             )
         else:
             with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
diff --git a/llm_on_ray/inference/predictors/transformer_predictor.py b/llm_on_ray/inference/predictors/transformer_predictor.py
index 2e51ab6a8..a840ee0ff 100644
--- a/llm_on_ray/inference/predictors/transformer_predictor.py
+++ b/llm_on_ray/inference/predictors/transformer_predictor.py
@@ -37,11 +37,15 @@ def __init__(self, infer_conf: InferenceConfig):
         super().__init__(infer_conf)
         model_desc = infer_conf.model_description
         model_config = model_desc.config
+        if infer_conf.model_description.config.use_auth_token:
+            auth_token = infer_conf.model_description.config.use_auth_token
+        else:
+            auth_token = None
         hf_config = AutoConfig.from_pretrained(
             model_desc.model_id_or_path,
             torchscript=True,
             trust_remote_code=model_config.trust_remote_code,
-            use_auth_token=infer_conf.model_description.config.use_auth_token,
+            use_auth_token=auth_token,
         )

         # decide correct torch type for loading HF model
@@ -74,7 +78,7 @@ def __init__(self, infer_conf: InferenceConfig):
             model = PeftModel.from_pretrained(
                 model,
                 model_desc.peft_model_id_or_path,
-                use_auth_token=infer_conf.model_description.config.use_auth_token,
+                use_auth_token=auth_token,
             )
             model = model.merge_and_unload()
diff --git a/llm_on_ray/inference/serve.py b/llm_on_ray/inference/serve.py
index a84717664..ecd3bdee8 100644
--- a/llm_on_ray/inference/serve.py
+++ b/llm_on_ray/inference/serve.py
@@ -20,7 +20,11 @@
 from llm_on_ray.inference.api_server_simple import serve_run
 from llm_on_ray.inference.api_server_openai import openai_serve_run
 from llm_on_ray.inference.predictor_deployment import PredictorDeployment
-from llm_on_ray.inference.inference_config import ModelDescription, InferenceConfig, all_models
+from llm_on_ray.inference.inference_config import (
+    ModelDescription,
+    InferenceConfig,
+    all_models,
+)


 def get_deployed_models(args):
@@ -90,6 +94,11 @@ def main(argv=None):
         type=str,
         help=f"Only used when config_file is None, valid values can be any items in {list(all_models.keys())}.",
     )
+    parser.add_argument(
+        "--list_model_ids",
+        action="store_true",
+        help="List all supported model IDs and their config file paths",
+    )
     parser.add_argument(
         "--simple",
         action="store_true",
@@ -130,6 +139,12 @@ def main(argv=None):

     args = parser.parse_args(argv)

+    all_models_name = list(all_models.keys())
+    if args.list_model_ids:
+        for model in all_models_name:
+            print(f"{model}: \tllm_on_ray/inference/models/{model}.yaml")
+        sys.exit(0)
+
     ray.init(address="auto")
     deployments, model_list = get_deployed_models(args)
     if args.simple:
diff --git a/pyproject.toml b/pyproject.toml
index 6a3c44685..5a8e89306 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
 dependencies = [
     "accelerate",
     "datasets>=2.14.6",
-    "numpy",
+    "numpy<2.0.0",
     "ray>=2.10",
     "ray[serve,tune]>=2.10",
     "typing>=3.7.4.3",
diff --git a/tests/inference/test_serve.py b/tests/inference/test_serve.py
index c9b757f9f..3958700d6 100644
--- a/tests/inference/test_serve.py
+++ b/tests/inference/test_serve.py
@@ -20,7 +20,7 @@

 # Parametrize the test function with different combinations of parameters
 @pytest.mark.parametrize(
-    "config_file, models, port, simple, keep_serve_termimal",
+    "config_file, models, port, simple, keep_serve_termimal, list_model_ids",
     [
         (
             config_file,
@@ -28,12 +28,14 @@
             port,
             simple,
             keep_serve_termimal,
+            list_model_ids,
         )
         for config_file in ["../.github/workflows/config/gpt2-ci.yaml"]
         for models in ["gpt2"]
         for port in [8000]
         for simple in [False]
         for keep_serve_termimal in [False]
+        for list_model_ids in [False, True]
     ],
 )
 def test_script(
@@ -42,25 +44,41 @@ def test_script(
     port,
     simple,
     keep_serve_termimal,
+    list_model_ids,
 ):
-    cmd_serve = ["python", "../llm_on_ray/inference/serve.py"]
-    if config_file is not None:
-        cmd_serve.append("--config_file")
-        cmd_serve.append(str(config_file))
-    if models is not None:
-        cmd_serve.append("--models")
-        cmd_serve.append(str(models))
-    if port is not None:
-        cmd_serve.append("--port")
-        cmd_serve.append(str(port))
-    if simple:
-        cmd_serve.append("--simple")
-    if keep_serve_termimal:
-        cmd_serve.append("--keep_serve_termimal")
+    cmd_serve = ["llm_on_ray-serve"]
+    if list_model_ids:
+        cmd_serve.append("--list_model_ids")
+    else:
+        if config_file is not None:
+            cmd_serve.append("--config_file")
+            cmd_serve.append(str(config_file))
+        elif models is not None:
+            cmd_serve.append("--models")
+            cmd_serve.append(str(models))
+        if port is not None:
+            cmd_serve.append("--port")
+            cmd_serve.append(str(port))
+        if simple:
+            cmd_serve.append("--simple")
+        if keep_serve_termimal:
+            cmd_serve.append("--keep_serve_termimal")
+
+    print(cmd_serve)
     result_serve = subprocess.run(cmd_serve, capture_output=True, text=True)

+    if list_model_ids:
+        output = result_serve.stdout.strip()
+        lines = output.split("\n")
+        assert len(lines) > 0, "No model IDs found in the output"
-    assert "Error" not in result_serve.stderr
-    assert result_serve.returncode == 0
-    print("Output of stderr:")
-    print(result_serve.stderr)
+        # Check if the model IDs are listed
+        for line in lines:
+            parts = line.split()
+            assert len(parts) == 2, f"Invalid line format: {line}"
+            model_id, config_path = parts
+
+            assert config_path.endswith(".yaml"), f"Invalid config path format: {config_path}"
+
+    assert result_serve.returncode == 0, print(
+        "\n" + "Output of stderr: " + "\n", result_serve.stderr
+    )
diff --git a/tests/inference/test_simple_protocol.py b/tests/inference/test_simple_protocol.py
new file mode 100644
index 000000000..070c8eb73
--- /dev/null
+++ b/tests/inference/test_simple_protocol.py
@@ -0,0 +1,90 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import subprocess
+import pytest
+import os
+from basic_set import start_serve
+import requests
+from llm_on_ray.inference.api_simple_backend.simple_protocol import (
+    SimpleRequest,
+    SimpleModelResponse,
+)
+
+
+executed_models = []
+
+
+# Parametrize the test function with different combinations of parameters
+# TODO: more models and combinations will be added and tested.
+@pytest.mark.parametrize(
+    "prompt,streaming_response,max_new_tokens,temperature,top_p, top_k",
+    [
+        (
+            prompt,
+            streaming_response,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        )
+        for prompt in ["Once upon a time", ""]
+        for streaming_response in [None, True, "error"]
+        for max_new_tokens in [None, 128, "error"]
+        for temperature in [None]
+        for top_p in [None]
+        for top_k in [None]
+    ],
+)
+def test_script(prompt, streaming_response, max_new_tokens, temperature, top_p, top_k):
+    global executed_models
+
+    # Only run start_serve once for this model
+    if "gpt2" not in executed_models:
+        start_serve("gpt2", simple=True)
+        # Record that start_serve has been run for this model
+        executed_models.append("gpt2")
+    config = {}
+    if max_new_tokens:
+        config["max_new_tokens"] = max_new_tokens
+    if temperature:
+        config["temperature"] = temperature
+    if top_p:
+        config["top_p"] = top_p
+    if top_k:
+        config["top_k"] = top_k
+
+    try:
+        sample_input = SimpleRequest(text=prompt, config=config, stream=streaming_response)
+    except ValueError as e:
+        print(e)
+        return
+    outputs = requests.post(
+        "http://localhost:8000/gpt2",
+        proxies={"http": None, "https": None},  # type: ignore
+        json=sample_input.dict(),
+        stream=streaming_response,
+    )
+
+    outputs.raise_for_status()
+
+    simple_response = SimpleModelResponse.from_requests_response(outputs)
+    if streaming_response:
+        for output in simple_response.iter_content(chunk_size=1, decode_unicode=True):
+            print(output, end="", flush=True)
+        print()
+    else:
+        print(simple_response.text, flush=True)
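For reference, a minimal client-side sketch of the `SimpleRequest`/`SimpleModelResponse` protocol introduced in this patch. It assumes a simple-mode deployment is already running (for example `llm_on_ray-serve --config_file llm_on_ray/inference/models/gpt2.yaml --simple`); the endpoint URL mirrors the one used in `tests/inference/test_simple_protocol.py` and is an assumption, not part of the patch.

```python
import requests

from llm_on_ray.inference.api_simple_backend.simple_protocol import (
    SimpleRequest,
    SimpleModelResponse,
)

# Assumed local endpoint of a simple-mode gpt2 deployment (matches the CI test).
ENDPOINT = "http://localhost:8000/gpt2"

# Valid request: only max_new_tokens/temperature/top_p/top_k are accepted config keys.
req = SimpleRequest(text="Once upon a time", config={"max_new_tokens": 32}, stream=False)
resp = SimpleModelResponse.from_requests_response(
    requests.post(ENDPOINT, json=req.dict(), proxies={"http": None, "https": None})
)
print(resp.status_code, resp.text)

# Invalid requests are rejected by the pydantic validators before any HTTP call is made.
try:
    SimpleRequest(text="", config={"max_new_tokens": 32})  # empty prompt
except ValueError as e:
    print(e)

try:
    SimpleRequest(text="hi", config={"num_beams": 4})  # key outside the allowed set
except ValueError as e:
    print(e)
```

Because validation now lives in `SimpleRequest`, malformed inputs fail on the client rather than reaching `PredictorDeployment.__call__`, which is why the server-side empty-prompt check could be dropped in `predictor_deployment.py`.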