ENH: Supports multi functions in tool call for qwen2 (xorbitsai#2265)
ChengjieLi28 authored Sep 11, 2024
1 parent c207bd3 commit 402cc7b
Showing 1 changed file with 78 additions and 84 deletions.
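
This commit changes the qwen tool-call parsing to return a list of (content, function, arguments) tuples instead of a single tuple, so a completion that contains several delimited tool-call blocks produces several function calls, and the completion builders below collect one tool_calls entry per parsed block. As a minimal standalone sketch of the new behavior (not the committed code), assuming QWEN_TOOL_CALL_SYMBOLS is the ("<tool_call>", "</tool_call>") delimiter pair emitted by qwen2-style models:

import json
from typing import List, Optional, Tuple

# Assumed delimiters; the real values live in xinference's LLM constants.
QWEN_TOOL_CALL_SYMBOLS = ("<tool_call>", "</tool_call>")


def parse_qwen_tool_calls(text: str) -> List[Tuple[Optional[str], Optional[str], Optional[dict]]]:
    # Mirror of the _handle_qwen_tool_result logic added in this commit:
    # split on the closing delimiter and parse each block independently.
    text = text.strip()
    if text.startswith(QWEN_TOOL_CALL_SYMBOLS[0]):
        text = text[len(QWEN_TOOL_CALL_SYMBOLS[0]) :]
    if text.endswith(QWEN_TOOL_CALL_SYMBOLS[1]):
        text = text[: -len(QWEN_TOOL_CALL_SYMBOLS[1])]
    results: List[Tuple[Optional[str], Optional[str], Optional[dict]]] = []
    for block in text.split(QWEN_TOOL_CALL_SYMBOLS[1]):
        block = block.strip()
        if not block:
            continue
        if block.startswith(QWEN_TOOL_CALL_SYMBOLS[0]):
            block = block[len(QWEN_TOOL_CALL_SYMBOLS[0]) :].strip()
        try:
            call = json.loads(block)
            results.append((None, call["name"], call["arguments"]))
        except Exception:
            # Unparseable block: keep the raw text as content with no function.
            results.append((block, None, None))
    return results


raw = (
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>\n'
    '<tool_call>{"name": "get_time", "arguments": {"tz": "UTC"}}</tool_call>'
)
print(parse_qwen_tool_calls(raw))
# [(None, 'get_weather', {'city': 'Beijing'}), (None, 'get_time', {'tz': 'UTC'})]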
xinference/model/llm/utils.py (162 changes: 78 additions, 84 deletions)
@@ -301,99 +301,89 @@ def _to_chat_completion(completion: Completion) -> ChatCompletion:
         }
 
     @staticmethod
-    def _eval_glm_chat_arguments(c):
+    def _eval_glm_chat_arguments(c) -> List[Tuple]:
         """
         Currently, glm4 tool call only supports one function
         """
         try:
             if isinstance(c, dict):
-                return None, c["name"], c["arguments"]
+                return [(None, c["name"], c["arguments"])]
         except KeyError:
             logger.error("Can't parse glm output: %s", c)
-            return str(c), None, None
+            return [(str(c), None, None)]
         else:
-            return str(c), None, None
+            return [(str(c), None, None)]
 
-    @staticmethod
-    def _eval_qwen_chat_arguments(c):
-        text = c["choices"][0]["text"]
+    @classmethod
+    def _handle_qwen_tool_result(cls, text: str) -> List[Tuple]:
         text: str = text.strip()  # type: ignore
         if text.startswith(QWEN_TOOL_CALL_SYMBOLS[0]):
             text = text[len(QWEN_TOOL_CALL_SYMBOLS[0]) :]
         if text.endswith(QWEN_TOOL_CALL_SYMBOLS[1]):
             text = text[: -len(QWEN_TOOL_CALL_SYMBOLS[1])]
         text = text.strip()
-        try:
-            content = json.loads(text)
-            return None, content["name"], content["arguments"]
-        except Exception as e:
-            logger.error("Can't parse qwen tool call output: %s. Error: %s", text, e)
-            return text, None, None
+        contents: List[str] = text.split(QWEN_TOOL_CALL_SYMBOLS[1])
+        results: List[Tuple] = []
+        for content in contents:
+            content = content.strip()
+            if content:
+                if content.startswith(QWEN_TOOL_CALL_SYMBOLS[0]):
+                    content = content[len(QWEN_TOOL_CALL_SYMBOLS[0]) :]
+                content = content.strip()
+                try:
+                    res = json.loads(content)
+                    results.append((None, res["name"], res["arguments"]))
+                except Exception as e:
+                    logger.error(
+                        "Can't parse single qwen tool call output: %s. Error: %s",
+                        content,
+                        e,
+                    )
+                    results.append((content, None, None))
+        return results
 
+    @classmethod
+    def _eval_qwen_chat_arguments(cls, c) -> List[Tuple]:
+        text = c["choices"][0]["text"]
+        return cls._handle_qwen_tool_result(text)
+
     @classmethod
     def _eval_tool_arguments(cls, model_family, c):
         family = model_family.model_family or model_family.model_name
         if family in GLM4_TOOL_CALL_FAMILY:
-            content, func, args = cls._eval_glm_chat_arguments(c)
+            result = cls._eval_glm_chat_arguments(c)
         elif family in QWEN_TOOL_CALL_FAMILY:
-            content, func, args = cls._eval_qwen_chat_arguments(c)
+            result = cls._eval_qwen_chat_arguments(c)
         else:
             raise Exception(
                 f"Model {model_family.model_name} is not support tool calls."
             )
-        logger.debug("Tool call content: %s, func: %s, args: %s", content, func, args)
-        return content, func, args
-
-    @classmethod
-    def _tools_token_filter(cls, model_family):
-        """
-        Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".
-        Returns:
-            A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
-            returns the part after "\nFinal Answer:" if found, else returns delta.
-        """
-        family = model_family.model_family or model_family.model_name
-        if family in QWEN_TOOL_CALL_FAMILY:
-            # Encapsulating function to reset 'found' after each call
-            found = False
-
-            def process_tokens(tokens: str, delta: str):
-                nonlocal found
-                # Once "Final Answer:" is found, future tokens are allowed.
-                if found:
-                    return delta
-                # Check if the token ends with "\nFinal Answer:" and update `found`.
-                final_answer_idx = tokens.lower().rfind("\nfinal answer:")
-                if final_answer_idx != -1:
-                    found = True
-                    return tokens[final_answer_idx + len("\nfinal answer:") :]
-                return ""
-
-            return process_tokens
-        else:
-            return lambda tokens, delta: delta
+        logger.debug(f"Tool call content: {result}")
+        return result
 
     @classmethod
     def _tool_calls_completion_chunk(cls, model_family, model_uid, c):
         _id = str(uuid.uuid4())
-        content, func, args = cls._eval_tool_arguments(model_family, c)
-        if func:
-            d = {
-                "role": "assistant",
-                "content": content,
-                "tool_calls": [
-                    {
-                        "id": f"call_{_id}",
-                        "type": "function",
-                        "function": {
-                            "name": func,
-                            "arguments": json.dumps(args, ensure_ascii=False),
-                        },
-                    }
-                ],
-            }
-            finish_reason = "tool_calls"
-        else:
-            d = {"role": "assistant", "content": content, "tool_calls": []}
-            finish_reason = "stop"
+        tool_result = cls._eval_tool_arguments(model_family, c)
+        tool_calls = []
+        failed_contents = []
+        for content, func, args in tool_result:
+            if func:
+                tool_calls.append(
+                    [
+                        {
+                            "id": f"call_{_id}",
+                            "type": "function",
+                            "function": {
+                                "name": func,
+                                "arguments": json.dumps(args, ensure_ascii=False),
+                            },
+                        }
+                    ]
+                )
+            else:
+                failed_contents.append(content)
+        finish_reason = "tool_calls" if tool_calls else "stop"
+        d = {
+            "role": "assistant",
+            "content": ". ".join(failed_contents) if failed_contents else None,
+            "tool_calls": tool_calls,
+        }
         try:
             usage = c.get("usage")
             assert "prompt_tokens" in usage
@@ -422,12 +412,13 @@ def _tool_calls_completion_chunk(cls, model_family, model_uid, c):
     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c):
         _id = str(uuid.uuid4())
-        content, func, args = cls._eval_tool_arguments(model_family, c)
-        if func:
-            m = {
-                "role": "assistant",
-                "content": content,
-                "tool_calls": [
+        tool_result = cls._eval_tool_arguments(model_family, c)
+
+        tool_calls = []
+        failed_contents = []
+        for content, func, args in tool_result:
+            if func:
+                tool_calls.append(
                     {
                         "id": f"call_{_id}",
                         "type": "function",
@@ -436,12 +427,15 @@ def _tool_calls_completion(cls, model_family, model_uid, c):
                             "arguments": json.dumps(args, ensure_ascii=False),
                         },
                     }
-                ],
-            }
-            finish_reason = "tool_calls"
-        else:
-            m = {"role": "assistant", "content": content, "tool_calls": []}
-            finish_reason = "stop"
+                )
+            else:
+                failed_contents.append(content)
+        finish_reason = "tool_calls" if tool_calls else "stop"
+        m = {
+            "role": "assistant",
+            "content": ". ".join(failed_contents) if failed_contents else None,
+            "tool_calls": tool_calls,
+        }
         try:
             usage = c.get("usage")
             assert "prompt_tokens" in usage