Skip to content

Commit

Permalink
Add tool calls api support, add search tool support
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus committed Dec 28, 2024
1 parent 8bb9ddc commit d1f4093
Show file tree
Hide file tree
Showing 14 changed files with 450 additions and 283 deletions.
1 change: 1 addition & 0 deletions etc/unittest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .image_client import *
from .include import *
from .retry_provider import *
from .web_search import *
from .models import *

unittest.main()
20 changes: 13 additions & 7 deletions etc/unittest/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,45 @@
class ProviderMock(AbstractProvider):
working = True

@classmethod
def create_completion(
model, messages, stream, **kwargs
cls, model, messages, stream, **kwargs
):
yield "Mock"

class AsyncProviderMock(AsyncProvider):
working = True

@classmethod
async def create_async(
model, messages, **kwargs
cls, model, messages, **kwargs
):
return "Mock"

class AsyncGeneratorProviderMock(AsyncGeneratorProvider):
working = True

@classmethod
async def create_async_generator(
model, messages, stream, **kwargs
cls, model, messages, stream, **kwargs
):
yield "Mock"

class ModelProviderMock(AbstractProvider):
working = True

@classmethod
def create_completion(
model, messages, stream, **kwargs
cls, model, messages, stream, **kwargs
):
yield model

class YieldProviderMock(AsyncGeneratorProvider):
working = True

@classmethod
async def create_async_generator(
model, messages, stream, **kwargs
cls, model, messages, stream, **kwargs
):
for message in messages:
yield message["content"]
Expand Down Expand Up @@ -84,8 +89,9 @@ async def create_async_generator(

class YieldNoneProviderMock(AsyncGeneratorProvider):
working = True


@classmethod
async def create_async_generator(
model, messages, stream, **kwargs
cls, model, messages, stream, **kwargs
):
yield None
82 changes: 82 additions & 0 deletions etc/unittest/web_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import json
import unittest

from g4f.client import AsyncClient
from g4f.web_search import DuckDuckGoSearchException, has_requirements
from .mocks import YieldProviderMock

DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
    """Tests for the ``search_tool`` tool-call integration of ``AsyncClient``.

    Each test sends a tool call named ``search_tool`` through a mock provider
    and asserts that the web-search instruction text was injected into the
    generated response. The three tests differ only in how the tool-call
    arguments are supplied (full dict, minimal dict, JSON string).
    """

    def setUp(self) -> None:
        # Skip the whole class when the optional web-search dependencies
        # are not installed (has_requirements is exported by g4f.web_search).
        if not has_requirements:
            self.skipTest('web search requirements not passed')

    async def _assert_search_response(self, tool_calls: list) -> None:
        """Run a chat completion with *tool_calls* and assert the search
        instruction prefix appears in the response content.

        Skips the test instead of failing when DuckDuckGo rate-limits us,
        since that is an external condition outside the test's control.
        """
        client = AsyncClient(provider=YieldProviderMock)
        try:
            response = await client.chat.completions.create(
                [{"content": "", "role": "user"}], "", tool_calls=tool_calls
            )
            self.assertIn(
                "Instruction: Using the provided web search results",
                response.choices[0].message.content
            )
        except DuckDuckGoSearchException as e:
            self.skipTest(f'DuckDuckGoSearchException: {e}')

    async def test_search(self):
        # Full argument set passed as a plain dict, including explicit
        # custom instructions.
        tool_calls = [
            {
                "function": {
                    "arguments": {
                        "query": "search query", # content of last message: messages[-1]["content"]
                        "max_results": 5, # maximum number of search results
                        "max_words": 500, # maximum number of used words from search results for generating the response
                        "backend": "html", # or "lite", "api": change it to bypass rate limits
                        "add_text": True, # do scraping websites
                        "timeout": 5, # in seconds for scraping websites
                        "region": "wt-wt",
                        "instructions": "Using the provided web search results, to write a comprehensive reply to the user request.\n"
                                        "Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)",
                    },
                    "name": "search_tool"
                },
                "type": "function"
            }
        ]
        await self._assert_search_response(tool_calls)

    async def test_search2(self):
        # Minimal arguments: only the query; all other options fall back
        # to their defaults.
        tool_calls = [
            {
                "function": {
                    "arguments": {
                        "query": "search query",
                    },
                    "name": "search_tool"
                },
                "type": "function"
            }
        ]
        await self._assert_search_response(tool_calls)

    async def test_search3(self):
        # Arguments serialized as a JSON string, matching how real
        # OpenAI-style tool calls deliver them over the wire.
        tool_calls = [
            {
                "function": {
                    "arguments": json.dumps({
                        "query": "search query", # content of last message: messages[-1]["content"]
                        "max_results": 5, # maximum number of search results
                        "max_words": 500, # maximum number of used words from search results for generating the response
                    }),
                    "name": "search_tool"
                },
                "type": "function"
            }
        ]
        await self._assert_search_response(tool_calls)
58 changes: 16 additions & 42 deletions g4f/Provider/Mhystical.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from __future__ import annotations

import json
import logging
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
from .needs_auth.OpenaiAPI import OpenaiAPI

"""
Mhystical.cc
Expand All @@ -19,39 +14,31 @@
"""

logger = logging.getLogger(__name__)

class Mhystical(AsyncGeneratorProvider, ProviderModelMixin):
class Mhystical(OpenaiAPI):
url = "https://api.mhystical.cc"
api_endpoint = "https://api.mhystical.cc/v1/completions"
working = True
needs_auth = False
supports_stream = False # Set to False, as streaming is not specified in ChatifyAI
supports_system_message = False
supports_message_history = True

default_model = 'gpt-4'
models = [default_model]
model_aliases = {}

@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
return model
elif model in cls.model_aliases:
return cls.model_aliases.get(model, cls.default_model)
else:
return cls.default_model
def get_model(cls, model: str, **kwargs) -> str:
cls.last_model = cls.default_model
return cls.default_model

@classmethod
async def create_async_generator(
def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
stream: bool = False,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)

headers = {
"x-api-key": "mhystical",
"Content-Type": "application/json",
Expand All @@ -61,24 +48,11 @@ async def create_async_generator(
"referer": f"{cls.url}/",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
}

async with ClientSession(headers=headers) as session:
data = {
"model": model,
"messages": [{"role": "user", "content": format_prompt(messages)}]
}
async with session.post(cls.api_endpoint, json=data, headers=headers, proxy=proxy) as response:
await raise_for_status(response)
response_text = await response.text()
filtered_response = cls.filter_response(response_text)
yield filtered_response

@staticmethod
def filter_response(response_text: str) -> str:
try:
json_response = json.loads(response_text)
message_content = json_response["choices"][0]["message"]["content"]
return message_content
except (KeyError, IndexError, json.JSONDecodeError) as e:
logger.error("Error parsing response: %s", e)
return "Error: Failed to parse response from API."
return super().create_async_generator(
model=model,
messages=messages,
stream=cls.supports_stream,
api_endpoint=cls.api_endpoint,
headers=headers,
**kwargs
)
15 changes: 12 additions & 3 deletions g4f/Provider/needs_auth/OpenaiAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason
from ...providers.response import FinishReason, ToolCalls, Usage
from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri
from ... import debug
Expand Down Expand Up @@ -51,6 +51,7 @@ async def create_async_generator(
timeout: int = 120,
images: ImagesType = None,
api_key: str = None,
api_endpoint: str = None,
api_base: str = None,
temperature: float = None,
max_tokens: int = None,
Expand All @@ -59,6 +60,7 @@ async def create_async_generator(
stream: bool = False,
headers: dict = None,
impersonate: str = None,
tools: Optional[list] = None,
extra_data: dict = {},
**kwargs
) -> AsyncResult:
Expand Down Expand Up @@ -93,16 +95,23 @@ async def create_async_generator(
top_p=top_p,
stop=stop,
stream=stream,
tools=tools,
**extra_data
)
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
if api_endpoint is None:
api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
async with session.post(api_endpoint, json=data) as response:
await raise_for_status(response)
if not stream:
data = await response.json()
cls.raise_error(data)
choice = data["choices"][0]
if "content" in choice["message"]:
if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip()
elif "tool_calls" in choice["message"]:
yield ToolCalls(choice["message"]["tool_calls"])
if "usage" in data:
yield Usage(**data["usage"])
finish = cls.read_finish_reason(choice)
if finish is not None:
yield finish
Expand Down
7 changes: 1 addition & 6 deletions g4f/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@
from starlette._compat import md5_hexdigest
from types import SimpleNamespace
from typing import Union, Optional, List
try:
from typing import Annotated
except ImportError:
class Annotated:
pass

import g4f
import g4f.debug
Expand All @@ -50,7 +45,7 @@ class Annotated:
ChatCompletionsConfig, ImageGenerationConfig,
ProviderResponseModel, ModelResponseModel,
ErrorResponseModel, ProviderResponseDetailModel,
FileResponseModel
FileResponseModel, Annotated
)

logger = logging.getLogger(__name__)
Expand Down
16 changes: 15 additions & 1 deletion g4f/api/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from pydantic import BaseModel, Field
from typing import Union, Optional

try:
from typing import Annotated
except ImportError:
class Annotated:
pass
from g4f.typing import Messages

class ChatCompletionsConfig(BaseModel):
Expand All @@ -23,6 +27,16 @@ class ChatCompletionsConfig(BaseModel):
history_disabled: Optional[bool] = None
auto_continue: Optional[bool] = None
timeout: Optional[int] = None
tool_calls: list = Field(default=[], examples=[[
{
"function": {
"arguments": {"query":"search query", "max_results":5, "max_words": 2500, "backend": "api", "add_text": True, "timeout": 5},
"name": "search_tool"
},
"type": "function"
}
]])
tools: list = None

class ImageGenerationConfig(BaseModel):
prompt: str
Expand Down
Loading

0 comments on commit d1f4093

Please sign in to comment.