
Commit

Merge pull request #699 from sunyuhan19981208/api_do_sample
Add do_sample argument for API demo
ymcui authored Jul 5, 2023
2 parents 4ef9477 + 8aca35f commit 8e45406
Showing 3 changed files with 13 additions and 11 deletions.
4 changes: 4 additions & 0 deletions scripts/openai_server_demo/README.md
@@ -120,6 +120,8 @@ JSON response body:
 
 `repetition_penalty`: repetition penalty; for details, see this paper: <https://arxiv.org/pdf/1909.05858.pdf>
 
+`do_sample`: enables the random sampling strategy. Defaults to true.
+
 ### Chat (chat completion)
 
 The chat endpoint supports multi-turn conversations.
@@ -240,6 +242,8 @@ JSON response body:
 
 `repetition_penalty`: repetition penalty; for details, see this paper: <https://arxiv.org/pdf/1909.05858.pdf>
 
+`do_sample`: enables the random sampling strategy. Defaults to true.
+
 ### Text embedding vectors (text embedding)
 
 Text embedding vectors have many uses, including but not limited to question answering over large documents, summarizing the contents of a book, and finding the memories closest to the current user input for a large language model.
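To see the documented flag in action, here is a minimal client sketch for the completion endpoint. It assumes the demo is running locally; the host, port, and `/v1/completions` path are assumptions based on the OpenAI-style API this demo exposes, not taken from this diff.

import requests

# Hypothetical local deployment; adjust the host/port to your setup.
API_URL = "http://localhost:19327/v1/completions"

# With do_sample=false the server should decode deterministically
# (greedy search here, since num_beams now defaults to 1), and the
# sampling knobs such as temperature/top_p/top_k no longer apply.
payload = {
    "prompt": "Tell me where the capital of China is",
    "max_tokens": 128,
    "do_sample": False,
}

resp = requests.post(API_URL, json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])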
6 changes: 4 additions & 2 deletions scripts/openai_server_demo/openai_api_protocol.py
@@ -13,11 +13,12 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[int] = 40
     n: Optional[int] = 1
     max_tokens: Optional[int] = 128
-    num_beams: Optional[int] = 4
+    num_beams: Optional[int] = 1
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     repetition_penalty: Optional[float] = 1.0
     user: Optional[str] = None
+    do_sample: Optional[bool] = True
 
 
 class ChatMessage(BaseModel):
@@ -58,11 +59,12 @@ class CompletionRequest(BaseModel):
     stream: Optional[bool] = False
     top_p: Optional[float] = 0.75
     top_k: Optional[int] = 40
-    num_beams: Optional[int] = 4
+    num_beams: Optional[int] = 1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     repetition_penalty: Optional[float] = 1.0
     user: Optional[str] = None
+    do_sample: Optional[bool] = True
 
 
 class CompletionResponseChoice(BaseModel):
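A quick illustration of what these pydantic defaults buy callers: the sketch below rebuilds a trimmed-down CompletionRequest with only the fields touched here (the real model has more; the `prompt` field is an assumption) and checks the new out-of-the-box behavior.

from typing import Optional
from pydantic import BaseModel

class CompletionRequest(BaseModel):
    # Trimmed to the fields relevant to this commit; `prompt` is assumed.
    prompt: Optional[str] = None
    num_beams: Optional[int] = 1   # was 4 before this commit
    do_sample: Optional[bool] = True

# Omitted fields fall back to their defaults: sampling on, no beam search.
req = CompletionRequest(prompt="hi")
assert req.do_sample is True and req.num_beams == 1

# An explicit override in the request body survives validation untouched.
req = CompletionRequest(prompt="hi", do_sample=False)
assert req.do_sample is False

Lowering the default num_beams from 4 to 1 complements the new flag: in transformers, num_beams > 1 combined with do_sample=True triggers beam-search sampling rather than plain multinomial sampling, so a default of 1 keeps ordinary sampling as the default path.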
14 changes: 5 additions & 9 deletions scripts/openai_server_demo/openai_api_server.py
@@ -32,15 +32,7 @@
     EmbeddingsRequest,
     EmbeddingsResponse,
 )
-generation_config = dict(
-    temperature=0.2,
-    top_k=40,
-    top_p=0.9,
-    do_sample=True,
-    num_beams=1,
-    repetition_penalty=1.1,
-    max_new_tokens=400
-)
+
 load_type = torch.float16
 if torch.cuda.is_available():
     device = torch.device(0)
@@ -113,6 +105,7 @@ def predict(
     top_k=40,
     num_beams=4,
     repetition_penalty=1.0,
+    do_sample=True,
     **kwargs,
 ):
     """
@@ -131,6 +124,7 @@
         top_p=top_p,
         top_k=top_k,
         num_beams=num_beams,
+        do_sample=do_sample,
         **kwargs,
     )
     with torch.no_grad():
@@ -188,6 +182,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         temperature=request.temperature,
         num_beams=request.num_beams,
         repetition_penalty=request.repetition_penalty,
+        do_sample=request.do_sample,
     )
     choices = [ChatCompletionResponseChoice(index = i, message = msg) for i, msg in enumerate(msgs)]
     choices += [ChatCompletionResponseChoice(index = len(choices), message = ChatMessage(role='assistant',content=output))]
@@ -204,6 +199,7 @@ async def create_completion(request: CompletionRequest):
         temperature=request.temperature,
         num_beams=request.num_beams,
         repetition_penalty=request.repetition_penalty,
+        do_sample=request.do_sample,
     )
     choices = [CompletionResponseChoice(index = 0, text = output)]
     return CompletionResponse(choices = choices)
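The net server-side effect of this commit: the hard-coded module-level generation_config is removed, and each request's do_sample is forwarded into the per-call generation settings. Below is a minimal sketch of that decoding path, assuming a standard transformers causal LM; the model id and helper name are illustrative, not from this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

MODEL_ID = "some/causal-lm"  # placeholder; the demo loads its own weights
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def generate(prompt: str, do_sample: bool = True) -> str:
    # Mirrors the per-request config built inside predict(): with
    # do_sample=False, transformers falls back to greedy/beam search and
    # warns that temperature/top_p/top_k are ignored.
    gen_config = GenerationConfig(
        temperature=0.2,
        top_p=0.9,
        top_k=40,
        num_beams=1,
        do_sample=do_sample,
        max_new_tokens=128,
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, generation_config=gen_config)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)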
