Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support native model ability to invoke tools #50

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b6ba03c
Support native tools (initial commit)
Oleksii-Klimov Dec 14, 2023
99e8322
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 2, 2024
d55b24d
Some intermediate fixes.
Oleksii-Klimov Jan 4, 2024
eb0d90a
Small fixes.
Oleksii-Klimov Jan 5, 2024
d6da3a9
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 5, 2024
cb5cbc9
More fixes.
Oleksii-Klimov Jan 8, 2024
f513d56
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 8, 2024
cad2dc1
Check for a reserved command name.
Oleksii-Klimov Jan 9, 2024
d012fee
Add extra line between commands.
Oleksii-Klimov Jan 10, 2024
41464dc
Update dial sdk to support httpx for opentelemetry.
Oleksii-Klimov Jan 10, 2024
5db6bb0
Remove unused import.
Oleksii-Klimov Jan 10, 2024
076c2ea
Clarify prompts.
Oleksii-Klimov Jan 10, 2024
eb88263
Minor prompt adjustments.
Oleksii-Klimov Jan 10, 2024
762cbc2
Improve prompt formatting for gpt-4-0314.
Oleksii-Klimov Jan 11, 2024
55886a1
Use latest openai api version to support tools.
Oleksii-Klimov Jan 11, 2024
ca0a1db
Address review comments.
Oleksii-Klimov Jan 12, 2024
4b33f4d
Rename method.
Oleksii-Klimov Jan 12, 2024
6581c9c
Remove redundant comment.
Oleksii-Klimov Jan 12, 2024
f7ef289
Fix tests.
Oleksii-Klimov Jan 12, 2024
c24d81e
Fix typo.
Oleksii-Klimov Jan 12, 2024
5663d6b
Remove redundant import.
Oleksii-Klimov Jan 12, 2024
10b613f
Add comment to clarify logic.
Oleksii-Klimov Jan 12, 2024
73d95eb
Prompt clarifications.
Oleksii-Klimov Jan 15, 2024
2dbf2fd
Address review comments.
Oleksii-Klimov Jan 15, 2024
bb11c20
Update .env.example.
Oleksii-Klimov Jan 15, 2024
94a93a0
Use official model name.
Oleksii-Klimov Jan 15, 2024
02844bd
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions aidial_assistant/app.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,33 @@
#!/usr/bin/env python3
import logging.config
import os
from pathlib import Path

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig
from starlette.responses import Response

from aidial_assistant.application.assistant_application import (
AssistantApplication,
)
from aidial_assistant.utils.log_config import get_log_config

log_level = os.getenv("LOG_LEVEL", "INFO")
config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs"))

logging.config.dictConfig(get_log_config(log_level))

telemetry_config = TelemetryConfig(
service_name="aidial-assistant", tracing=TracingConfig()
)
app = DIALApp(telemetry_config=telemetry_config)
app.add_chat_completion("assistant", AssistantApplication(config_dir))

# A delayed import is necessary to set up the httpx hook before the openai client inherits from AsyncClient.
from aidial_assistant.application.assistant_application import ( # noqa: E402
AssistantApplication,
)

@app.get("/healthcheck/status200")
def status200() -> Response:
return Response("Service is running...", status_code=200)
config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs"))
tools_supporting_deployments: set[str] = set(
os.getenv(
"TOOLS_SUPPORTING_DEPLOYMENTS", "gpt-4-turbo-1106,anthropic.claude-v2-1"
).split(",")
)
app.add_chat_completion(
"assistant",
AssistantApplication(config_dir, tools_supporting_deployments),
)
7 changes: 5 additions & 2 deletions aidial_assistant/application/addons_dialogue_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
LimitExceededException,
ModelRequestLimiter,
)
from aidial_assistant.model.model_client import Message, ModelClient
from aidial_assistant.model.model_client import (
ChatCompletionMessageParam,
ModelClient,
)


class AddonsDialogueLimiter(ModelRequestLimiter):
Expand All @@ -16,7 +19,7 @@ def __init__(self, max_dialogue_tokens: int, model_client: ModelClient):
self._initial_tokens: int | None = None

@override
async def verify_limit(self, messages: list[Message]):
async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
if self._initial_tokens is None:
self._initial_tokens = await self.model_client.count_tokens(
messages
Expand Down
153 changes: 126 additions & 27 deletions aidial_assistant/application/assistant_application.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
from pathlib import Path
from typing import Tuple

from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
from aidial_sdk.chat_completion.response import Response
from openai.lib.azure import AsyncAzureOpenAI
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel

from aidial_assistant.application.addons_dialogue_limiter import (
Expand All @@ -18,18 +21,29 @@
MAIN_BEST_EFFORT_TEMPLATE,
MAIN_SYSTEM_DIALOG_MESSAGE,
)
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.command_chain import (
CommandChain,
CommandConstructor,
CommandDict,
)
from aidial_assistant.chain.history import History
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
from aidial_assistant.model.model_client import (
ModelClient,
ReasonLengthException,
)
from aidial_assistant.tools_chain.tools_chain import (
CommandToolDict,
ToolsChain,
convert_commands_to_tools,
)
from aidial_assistant.utils.exceptions import (
RequestParameterValidationError,
unhandled_exception_handler,
)
from aidial_assistant.utils.open_ai import construct_tool
from aidial_assistant.utils.open_ai_plugin import (
AddonTokenSource,
get_open_ai_plugin_info,
Expand All @@ -49,8 +63,6 @@ def _get_request_args(request: Request) -> dict[str, str]:
args = {
"model": request.model,
"temperature": request.temperature,
"api_version": request.api_version,
"api_key": request.api_key,
"user": request.user,
}

Expand Down Expand Up @@ -83,68 +95,114 @@ def _validate_messages(messages: list[Message]) -> None:
)


def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
    """Describe an addon as a single-parameter tool.

    The tool exposes one required string argument, "query", carrying the
    natural-language task that the model delegates to the addon.
    """
    parameters = {
        "query": {
            "type": "string",
            "description": "A task written in natural language",
        }
    }
    return construct_tool(name, description, parameters, ["query"])


class AssistantApplication(ChatCompletion):
def __init__(self, config_dir: Path):
def __init__(
self, config_dir: Path, tools_supporting_deployments: set[str]
):
self.args = parse_args(config_dir)
self.tools_supporting_deployments = tools_supporting_deployments

@unhandled_exception_handler
async def chat_completion(
self, request: Request, response: Response
) -> None:
_validate_messages(request.messages)
addon_references = _validate_addons(request.addons)
chat_args = self.args.openai_conf.dict() | _get_request_args(request)
chat_args = _get_request_args(request)

model = ModelClient(
model_args=chat_args
| {
"deployment_id": chat_args["model"],
"api_type": "azure",
"stream": True,
},
buffer_size=self.args.chat_conf.buffer_size,
client=AsyncAzureOpenAI(
azure_endpoint=self.args.openai_conf.api_base,
api_key=request.api_key,
# 2023-12-01-preview is needed to support tools
api_version="2023-12-01-preview",
),
model_args=chat_args,
)

token_source = AddonTokenSource(
request.headers,
(addon_reference.url for addon_reference in addon_references),
)

addons: dict[str, PluginInfo] = {}
plugins: list[PluginInfo] = []
# DIAL Core has its own names for addons, so in stages we need to map them to the names used by the user
addon_name_mapping: dict[str, str] = {}
for addon_reference in addon_references:
info = await get_open_ai_plugin_info(addon_reference.url)
addons[info.ai_plugin.name_for_model] = PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
plugins.append(
PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
)
)

if addon_reference.name:
addon_name_mapping[
info.ai_plugin.name_for_model
] = addon_reference.name

if request.model in self.tools_supporting_deployments:
await AssistantApplication._run_native_tools_chat(
model, plugins, addon_name_mapping, request, response
)
else:
await AssistantApplication._run_emulated_tools_chat(
model, plugins, addon_name_mapping, request, response
)

@staticmethod
async def _run_emulated_tools_chat(
model: ModelClient,
addons: list[PluginInfo],
addon_name_mapping: dict[str, str],
request: Request,
response: Response,
):
# TODO: Add max_addons_dialogue_tokens as a request parameter
max_addons_dialogue_tokens = 1000

def create_command(addon: PluginInfo):
return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens)

command_dict: CommandDict = {
RunPlugin.token(): lambda: RunPlugin(
model, addons, max_addons_dialogue_tokens
),
Reply.token(): Reply,
addon.info.ai_plugin.name_for_model: create_command(addon)
for addon in addons
}
if Reply.token() in command_dict:
RequestParameterValidationError(
f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.",
param="addons",
)

command_dict[Reply.token()] = Reply

chain = CommandChain(
model_client=model, name="ASSISTANT", command_dict=command_dict
)
addon_descriptions = {
name: addon.info.open_api.info.description
addon.info.ai_plugin.name_for_model: addon.info.open_api.info.description
or addon.info.ai_plugin.description_for_human
for name, addon in addons.items()
for addon in addons
}
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
Expand Down Expand Up @@ -187,3 +245,44 @@ async def chat_completion(

if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)

@staticmethod
async def _run_native_tools_chat(
    model: ModelClient,
    plugins: list[PluginInfo],
    addon_name_mapping: dict[str, str],
    request: Request,
    response: Response,
):
    """Drive the conversation through the model's native tool-calling API.

    Each addon plugin is exposed to the model as one tool; invocations are
    executed via RunTool and streamed into the single response choice.
    """

    def as_command_tool(
        plugin: PluginInfo,
    ) -> Tuple[CommandConstructor, ChatCompletionToolParam]:
        # The lambda closes over `plugin` so every tool builds its own command.
        tool = _construct_tool(
            plugin.info.ai_plugin.name_for_model,
            plugin.info.ai_plugin.description_for_human,
        )
        return lambda: RunTool(model, plugin), tool

    command_tool_dict: CommandToolDict = {}
    for plugin in plugins:
        command_tool_dict[plugin.info.ai_plugin.name_for_model] = (
            as_command_tool(plugin)
        )
    chain = ToolsChain(model, command_tool_dict)

    choice = response.create_single_choice()
    choice.open()

    callback = AssistantChainCallback(choice, addon_name_mapping)
    messages = convert_commands_to_tools(parse_history(request.messages))
    finish_reason = FinishReason.STOP
    try:
        await chain.run_chat(messages, callback)
    except ReasonLengthException:
        # Token budget exhausted mid-chat: report LENGTH instead of STOP.
        finish_reason = FinishReason.LENGTH

    if callback.invocations:
        choice.set_state(State(invocations=callback.invocations))
    choice.close(finish_reason)

    response.set_usage(
        model.total_prompt_tokens, model.total_completion_tokens
    )
58 changes: 2 additions & 56 deletions aidial_assistant/application/assistant_callback.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,18 @@
from types import TracebackType
from typing import Callable

from aidial_sdk.chat_completion import Status
from aidial_sdk.chat_completion.choice import Choice
from aidial_sdk.chat_completion.stage import Stage
from typing_extensions import override

from aidial_assistant.chain.callbacks.arg_callback import ArgCallback
from aidial_assistant.chain.callbacks.args_callback import ArgsCallback
from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
from aidial_assistant.commands.base import ExecutionCallback, ResultObject
from aidial_assistant.commands.run_plugin import RunPlugin
from aidial_assistant.utils.state import Invocation


class PluginNameArgCallback(ArgCallback):
    """Streams the plugin-name argument of a RunPlugin command.

    Chunks are accumulated with surrounding double quotes stripped; once the
    argument is complete, the callback receives the user-facing addon name
    (per the mapping) followed by an opening parenthesis.
    """

    def __init__(
        self,
        callback: Callable[[str], None],
        addon_name_mapping: dict[str, str],
    ):
        super().__init__(0, callback)
        self.addon_name_mapping = addon_name_mapping
        self._plugin_name = ""

    @override
    def on_arg(self, chunk: str):
        # Quote characters are JSON presentation noise around the name.
        self._plugin_name += chunk.replace('"', "")

    @override
    def on_arg_end(self):
        display_name = self.addon_name_mapping.get(
            self._plugin_name, self._plugin_name
        )
        self.callback(display_name + "(")


class RunPluginArgsCallback(ArgsCallback):
    """Args callback specialized for RunPlugin commands.

    Suppresses the default argument-list opening and routes the first
    argument (the plugin name) through PluginNameArgCallback; subsequent
    arguments fall back to the plain ArgCallback, re-indexed from zero.
    """

    def __init__(
        self,
        callback: Callable[[str], None],
        addon_name_mapping: dict[str, str],
    ):
        super().__init__(callback)
        self.addon_name_mapping = addon_name_mapping

    @override
    def on_args_start(self):
        # PluginNameArgCallback emits its own "(" once the name is known.
        pass

    @override
    def arg_callback(self) -> ArgCallback:
        self.arg_index += 1
        if self.arg_index == 0:
            return PluginNameArgCallback(self.callback, self.addon_name_mapping)
        return ArgCallback(self.arg_index - 1, self.callback)


class AssistantCommandCallback(CommandCallback):
def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]):
self.stage = stage
Expand All @@ -71,20 +22,15 @@ def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]):

@override
def on_command(self, command: str):
if command == RunPlugin.token():
self._args_callback = RunPluginArgsCallback(
self._on_stage_name, self.addon_name_mapping
)
else:
self._on_stage_name(command)
self._on_stage_name(self.addon_name_mapping.get(command, command))

@override
def execution_callback(self) -> ExecutionCallback:
return self._on_stage_content

@override
def args_callback(self) -> ArgsCallback:
return self._args_callback
return ArgsCallback(self._on_stage_name)

@override
def on_result(self, result: ResultObject):
Expand Down
Loading
Loading