Skip to content

Commit

Permalink
feat: added helpers for tool_calls/function_call
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Dec 12, 2023
1 parent cca9d4d commit 672bb95
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
21 changes: 20 additions & 1 deletion aidial_sdk/chat_completion/choice.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from asyncio import Queue
from types import TracebackType
from typing import Any, Optional, Type
from typing import Any, List, Optional, Type

from aidial_sdk.chat_completion.chunks import (
AttachmentChunk,
ContentChunk,
EndChoiceChunk,
FunctionCallChunk,
StartChoiceChunk,
StateChunk,
ToolCallsChunk,
)
from aidial_sdk.chat_completion.enums import FinishReason
from aidial_sdk.chat_completion.request import FunctionCall, ToolCall
from aidial_sdk.chat_completion.stage import Stage
from aidial_sdk.pydantic_v1 import ValidationError
from aidial_sdk.utils.errors import runtime_error
Expand Down Expand Up @@ -54,6 +57,22 @@ def append_content(self, content: str) -> None:

self._queue.put_nowait(ContentChunk(content, self._index))

def add_tool_calls(self, tool_calls: List[ToolCall]) -> None:
if not self._opened:
runtime_error("Trying to add tool call to an unopened choice")
if self._closed:
runtime_error("Trying to add tool call to a closed choice")

self._queue.put_nowait(ToolCallsChunk(tool_calls, self._index))

def add_function_call(self, function_call: FunctionCall) -> None:
if not self._opened:
runtime_error("Trying to add function call to an unopened choice")
if self._closed:
runtime_error("Trying to add function call to a closed choice")

self._queue.put_nowait(FunctionCallChunk(function_call, self._index))

def add_attachment(
self,
type: Optional[str] = None,
Expand Down
47 changes: 46 additions & 1 deletion aidial_sdk/chat_completion/chunks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from aidial_sdk.chat_completion.enums import FinishReason, Status
from aidial_sdk.chat_completion.request import FunctionCall, ToolCall
from aidial_sdk.pydantic_v1 import BaseModel, root_validator


Expand Down Expand Up @@ -72,6 +73,50 @@ def to_dict(self):
}


class ToolCallsChunk(BaseChunk):
tool_calls: List[ToolCall]
choice_index: int

def __init__(self, tool_calls: List[ToolCall], choice_index: int):
self.tool_calls = tool_calls
self.choice_index = choice_index

def to_dict(self):
return {
"choices": [
{
"index": self.choice_index,
"finish_reason": "tool_calls",
"delta": {
"tool_calls": [c.dict() for c in self.tool_calls]
},
}
],
"usage": None,
}


class FunctionCallChunk(BaseChunk):
function_call: FunctionCall
choice_index: int

def __init__(self, function_call: FunctionCall, choice_index: int):
self.function_call = function_call
self.choice_index = choice_index

def to_dict(self):
return {
"choices": [
{
"index": self.choice_index,
"finish_reason": "function_call",
"delta": {"function_call": self.function_call.dict()},
}
],
"usage": None,
}


class StartStageChunk(BaseChunk):
choice_index: int
stage_index: int
Expand Down
4 changes: 2 additions & 2 deletions aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ class AssistantMessage(ExtraForbidModel):

class ToolMessage(ExtraForbidModel):
role: Literal["tool"]
content: StrictStr
tool_call_id: StrictStr
content: StrictStr


class FunctionMessage(ExtraForbidModel):
role: Literal["function"]
content: StrictStr
name: StrictStr
content: StrictStr


Message = Annotated[
Expand Down

0 comments on commit 672bb95

Please sign in to comment.