From efa1f01b59ed0b79bdf7b0abd3111dc5fa5a6005 Mon Sep 17 00:00:00 2001
From: Jiayi Ni
Date: Mon, 21 Aug 2023 13:02:26 +0800
Subject: [PATCH 1/4] use restful client

---
 xinference/deploy/cmdline.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py
index 6b67c67c61..ce84010b9a 100644
--- a/xinference/deploy/cmdline.py
+++ b/xinference/deploy/cmdline.py
@@ -24,7 +24,6 @@
 
 from .. import __version__
 from ..client import (
-    Client,
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulClient,
@@ -354,9 +353,7 @@ def model_generate(
 ):
     endpoint = get_endpoint(endpoint)
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def generate_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -365,7 +362,7 @@ async def generate_internal():
                 if prompt == "":
                     break
                 print(f"Completion: {prompt}", end="", file=sys.stdout)
-                async for chunk in model.generate(
+                for chunk in model.generate(
                     prompt=prompt,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
                 ):
@@ -376,7 +373,7 @@ async def generate_internal():
                         print(choice["text"], end="", flush=True, file=sys.stdout)
             print("\n", file=sys.stdout)
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()
@@ -436,9 +433,7 @@ def model_chat(
     endpoint = get_endpoint(endpoint)
     chat_history: "List[ChatCompletionMessage]" = []
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def chat_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -449,7 +444,7 @@ async def chat_internal():
                 chat_history.append(ChatCompletionMessage(role="user", content=prompt))
                 print("Assistant: ", end="", file=sys.stdout)
                 response_content = ""
-                async for chunk in model.chat(
+                for chunk in model.chat(
                     prompt=prompt,
                     chat_history=chat_history,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
@@ -465,7 +460,7 @@ async def chat_internal():
                     ChatCompletionMessage(role="assistant", content=response_content)
                 )
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()

From 9a1cfcda62e3b45b4711eb3453fafb14bb309a08 Mon Sep 17 00:00:00 2001
From: Jiayi Ni
Date: Fri, 25 Aug 2023 15:58:17 +0800
Subject: [PATCH 2/4] fix test_cmdline

---
 xinference/deploy/test/test_cmdline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xinference/deploy/test/test_cmdline.py b/xinference/deploy/test/test_cmdline.py
index 2d9be3be32..dbf28bed1c 100644
--- a/xinference/deploy/test/test_cmdline.py
+++ b/xinference/deploy/test/test_cmdline.py
@@ -18,7 +18,7 @@
 import pytest
 from click.testing import CliRunner
 
-from ...client import Client
+from ...client import RESTfulClient
 from ..cmdline import (
     list_model_registrations,
     model_chat,
@@ -59,7 +59,7 @@ def test_cmdline(setup, stream):
     """
     # if use `model_launch` command to launch model, CI will fail.
     # So use client to launch model in temporary
-    client = Client(endpoint)
+    client = RESTfulClient(endpoint)
     model_uid = client.launch_model(
         model_name="orca", model_size_in_billions=3, quantization="q4_0"
     )

From 049e524ec1eb246e30cdb734cec2e8e097294ccd Mon Sep 17 00:00:00 2001
From: UranusSeven <109661872+UranusSeven@users.noreply.github.com>
Date: Tue, 29 Aug 2023 15:53:59 +0800
Subject: [PATCH 3/4] Remove isolation

---
 xinference/deploy/cmdline.py | 185 ++++++++++++-----------------------
 1 file changed, 62 insertions(+), 123 deletions(-)

diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py
index ce84010b9a..78ad97f21d 100644
--- a/xinference/deploy/cmdline.py
+++ b/xinference/deploy/cmdline.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import configparser
 import logging
 import os
@@ -35,7 +34,6 @@
     XINFERENCE_DEFAULT_LOCAL_HOST,
     XINFERENCE_ENV_ENDPOINT,
 )
-from ..isolation import Isolation
 from ..types import ChatCompletionMessage
 
 try:
@@ -352,66 +350,38 @@ def model_generate(
     stream: bool,
 ):
     endpoint = get_endpoint(endpoint)
-    if stream:
-
-        async def generate_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("Prompt: ")
-                if prompt == "":
-                    break
-                print(f"Completion: {prompt}", end="", file=sys.stdout)
-                for chunk in model.generate(
-                    prompt=prompt,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    choice = chunk["choices"][0]
-                    if "text" not in choice:
-                        continue
-                    else:
-                        print(choice["text"], end="", flush=True, file=sys.stdout)
-            print("\n", file=sys.stdout)
-
-        client = RESTfulClient(base_url=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = generate_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)):
+        raise ValueError(f"model {model_uid} has no generate method")
+
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("Prompt: ")
+        if prompt.lower() == "exit" or prompt.lower() == "quit":
+            break
+        print(f"Completion: {prompt}", end="", file=sys.stdout)
+
+        if stream:
+            for chunk in model.generate(
+                prompt=prompt,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            ):
+                choice = chunk["choices"][0]
+                if "text" not in choice:
+                    continue
+                else:
+                    print(choice["text"], end="", flush=True, file=sys.stdout)
         else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no generate method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            print(f"Assistant: {prompt}", end="", file=sys.stdout)
-            response = restful_model.generate(
+            response = model.generate(
                 prompt=prompt,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
             if not isinstance(response, dict):
                 raise ValueError("generate result is not valid")
-            print(f"{response['choices'][0]['text']}\n", file=sys.stdout)
+            print(f"{response['choices'][0]['text']}", file=sys.stdout)
+        print("\n", file=sys.stdout)
 
 
 @cli.command("chat")
@@ -431,80 +401,49 @@ def model_chat(
 ):
     # TODO: chat model roles may not be user and assistant.
     endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(
+        model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+    ):
+        raise ValueError(f"model {model_uid} has no chat method")
+
     chat_history: "List[ChatCompletionMessage]" = []
-    if stream:
-
-        async def chat_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("User: ")
-                if prompt == "":
-                    break
-                chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-                print("Assistant: ", end="", file=sys.stdout)
-                response_content = ""
-                for chunk in model.chat(
-                    prompt=prompt,
-                    chat_history=chat_history,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    delta = chunk["choices"][0]["delta"]
-                    if "content" not in delta:
-                        continue
-                    else:
-                        response_content += delta["content"]
-                        print(delta["content"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-                chat_history.append(
-                    ChatCompletionMessage(role="assistant", content=response_content)
-                )
-
-        client = RESTfulClient(base_url=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = chat_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("User: ")
+        if prompt == "":
+            break
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))
+        print("Assistant: ", end="", file=sys.stdout)
+
+        response_content = ""
+        if stream:
+            for chunk in model.chat(
+                prompt=prompt,
+                chat_history=chat_history,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            ):
+                delta = chunk["choices"][0]["delta"]
+                if "content" not in delta:
+                    continue
+                else:
+                    response_content += delta["content"]
+                    print(delta["content"], end="", flush=True, file=sys.stdout)
         else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no chat method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-            print("Assistant: ", end="", file=sys.stdout)
-            response = restful_model.chat(
+            response = model.chat(
                 prompt=prompt,
                 chat_history=chat_history,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
            )
-            if not isinstance(response, dict):
-                raise ValueError("chat result is not valid")
             response_content = response["choices"][0]["message"]["content"]
-            print(f"{response_content}\n", file=sys.stdout)
-            chat_history.append(
-                ChatCompletionMessage(role="assistant", content=response_content)
-            )
+            print(f"{response_content}", file=sys.stdout)
+
+        chat_history.append(
+            ChatCompletionMessage(role="assistant", content=response_content)
+        )
+        print("\n", file=sys.stdout)
 
 
 if __name__ == "__main__":

From 6b4c06f07419d2e329fc86c78898eb6214d53f3c Mon Sep 17 00:00:00 2001
From: UranusSeven <109661872+UranusSeven@users.noreply.github.com>
Date: Wed, 30 Aug 2023 10:57:15 +0800
Subject: [PATCH 4/4] Fix mypy

---
 xinference/deploy/cmdline.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py
index 78ad97f21d..dcb9dbedaa 100644
--- a/xinference/deploy/cmdline.py
+++ b/xinference/deploy/cmdline.py
@@ -364,10 +364,12 @@ def model_generate(
         print(f"Completion: {prompt}", end="", file=sys.stdout)
 
         if stream:
-            for chunk in model.generate(
+            iter = model.generate(
                 prompt=prompt,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
-            ):
+            )
+            assert not isinstance(iter, dict)
+            for chunk in iter:
                 choice = chunk["choices"][0]
                 if "text" not in choice:
                     continue
@@ -378,8 +380,7 @@ def model_generate(
                 prompt=prompt,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
-            if not isinstance(response, dict):
-                raise ValueError("generate result is not valid")
+            assert isinstance(response, dict)
             print(f"{response['choices'][0]['text']}", file=sys.stdout)
         print("\n", file=sys.stdout)
 
@@ -420,11 +421,13 @@ def model_chat(
 
         response_content = ""
         if stream:
-            for chunk in model.chat(
+            iter = model.chat(
                 prompt=prompt,
                 chat_history=chat_history,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
-            ):
+            )
+            assert not isinstance(iter, dict)
+            for chunk in iter:
                 delta = chunk["choices"][0]["delta"]
                 if "content" not in delta:
                     continue
@@ -437,6 +440,7 @@ def model_chat(
                 chat_history=chat_history,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
+            assert isinstance(response, dict)
             response_content = response["choices"][0]["message"]["content"]
             print(f"{response_content}", file=sys.stdout)
 
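
Note for anyone trying the series locally: below is a minimal sketch of the client-side generate flow the refactored command converges on. The endpoint URL and model UID are placeholders, and the model is assumed to have been launched beforehand (for example with client.launch_model, as test_cmdline.py does); only the RESTfulClient calls themselves come from the patches.

# Minimal sketch of the RESTfulClient generate flow used by the `xinference generate` command.
# Assumptions: endpoint URL and model UID below are placeholders; the model is already launched.
import sys

from xinference.client import RESTfulClient

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model-uid")  # placeholder model UID

# Streaming: with stream=True, generate() yields chunks whose choices may carry "text".
for chunk in model.generate(
    prompt="Hello", generate_config={"stream": True, "max_tokens": 64}
):
    choice = chunk["choices"][0]
    if "text" in choice:
        print(choice["text"], end="", flush=True, file=sys.stdout)
print()

# Non-streaming: generate() returns a plain dict, which is what the
# `assert isinstance(response, dict)` added in PATCH 4/4 relies on.
response = model.generate(
    prompt="Hello", generate_config={"stream": False, "max_tokens": 64}
)
print(response["choices"][0]["text"])

The two return shapes (iterator of chunks vs. a single dict) are exactly why the mypy fix narrows the type with asserts instead of runtime ValueError checks.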
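
The corresponding chat flow, with the same placeholder endpoint and model UID; the delta and message accesses mirror what the chat command does after PATCH 3/4.

# Minimal sketch of the RESTfulClient chat flow used by the `xinference chat` command.
# Assumptions: endpoint URL and model UID are placeholders; the model is already launched.
import sys

from xinference.client import RESTfulClient
from xinference.types import ChatCompletionMessage

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model-uid")  # placeholder model UID

chat_history = []
prompt = "Hi there"
chat_history.append(ChatCompletionMessage(role="user", content=prompt))

# Streaming: each chunk's delta may or may not contain "content".
response_content = ""
for chunk in model.chat(
    prompt=prompt,
    chat_history=chat_history,
    generate_config={"stream": True, "max_tokens": 64},
):
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:
        response_content += delta["content"]
        print(delta["content"], end="", flush=True, file=sys.stdout)
print()

# Append the assistant reply so the next turn sees the full history.
chat_history.append(
    ChatCompletionMessage(role="assistant", content=response_content)
)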