From 62831b1903cf96d4106c3ffd636fde706d261a5b Mon Sep 17 00:00:00 2001
From: KepingYan
Date: Thu, 16 Nov 2023 11:30:35 +0800
Subject: [PATCH] fix bug of serving llama2 on UI (#120)

---
 inference/chat_process.py | 2 +-
 inference/config.py       | 5 +++--
 inference/start_ui.py     | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/inference/chat_process.py b/inference/chat_process.py
index 421b8898a..2df19dad0 100644
--- a/inference/chat_process.py
+++ b/inference/chat_process.py
@@ -84,7 +84,7 @@ def prepare_prompt(self, messages: list):
             role, content = msg["role"], msg["content"]
             if role == "user":
                 if self.human_id != "":
-                    prompt += self.human_id.format(content)
+                    prompt += self.human_id.format(msg=content)
                 else:
                     prompt += f"{content}\n"
             elif role == "assistant":
diff --git a/inference/config.py b/inference/config.py
index 78139ea49..274be60e5 100644
--- a/inference/config.py
+++ b/inference/config.py
@@ -119,7 +119,7 @@
     }
 }
 
-llama2 = {
+llama2_7b = {
     "model_id_or_path": "meta-llama/Llama-2-7b-chat-hf",
     "tokenizer_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
     "port": "8000",
@@ -143,7 +143,7 @@
     "bloom": bloom,
     "opt": opt,
     "mpt": mpt,
-    "llama2": llama2
+    "llama2_7b": llama2_7b
 }
 
 env_model = "MODEL_TO_SERVE"
@@ -155,3 +155,4 @@
 
 base_models["gpt2"] = gpt2
 base_models["gpt-j-6B"] = gpt_j_6B
+base_models["llama2-7b"] = llama2_7b
diff --git a/inference/start_ui.py b/inference/start_ui.py
index 661b7c2f3..5cd63f4fd 100644
--- a/inference/start_ui.py
+++ b/inference/start_ui.py
@@ -2,7 +2,7 @@
 from config import all_models, base_models
 import time
 import os
-from chat_process import ChatModelGptJ
+from chat_process import ChatModelGptJ, ChatModelLLama
 import torch
 from run_model_serve import PredictDeployment
 from ray import serve
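
A minimal sketch of the str.format behavior the prepare_prompt change relies on, assuming the llama2 human_id template uses a named "{msg}" placeholder; the template string below is illustrative and is not taken from inference/config.py:

    # Illustrative template; the actual llama2 human_id in inference/config.py is
    # assumed to contain a named "{msg}" field rather than a positional "{}".
    human_id = "[INST] {msg} [/INST]"
    content = "Hello!"

    # Old call site: a positional argument cannot fill a named field.
    try:
        human_id.format(content)
    except KeyError as err:
        print("unfilled placeholder:", err)   # KeyError: 'msg'

    # Patched call site: the keyword argument matches the placeholder name.
    print(human_id.format(msg=content))       # [INST] Hello! [/INST]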