From d1f4093e1ab5ee957afa63368f4c935374a53121 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sat, 28 Dec 2024 16:21:14 +0100 Subject: [PATCH] Add tool calls api support, add search tool support --- etc/unittest/__main__.py | 1 + etc/unittest/mocks.py | 20 ++-- etc/unittest/web_search.py | 82 +++++++++++++ g4f/Provider/Mhystical.py | 58 +++------ g4f/Provider/needs_auth/OpenaiAPI.py | 15 ++- g4f/api/__init__.py | 7 +- g4f/api/stubs.py | 16 ++- g4f/client/__init__.py | 122 ++++++++++++++----- g4f/client/stubs.py | 28 +++-- g4f/gui/server/api.py | 2 +- g4f/gui/server/internet.py | 159 +------------------------ g4f/image.py | 5 +- g4f/providers/response.py | 46 ++++--- g4f/web_search.py | 172 +++++++++++++++++++++++++++ 14 files changed, 450 insertions(+), 283 deletions(-) create mode 100644 etc/unittest/web_search.py create mode 100644 g4f/web_search.py diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py index 6594e6a2677..5a29da34b43 100644 --- a/etc/unittest/__main__.py +++ b/etc/unittest/__main__.py @@ -8,6 +8,7 @@ from .image_client import * from .include import * from .retry_provider import * +from .web_search import * from .models import * unittest.main() \ No newline at end of file diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py index c43d98ccaf9..50d1a5a4075 100644 --- a/etc/unittest/mocks.py +++ b/etc/unittest/mocks.py @@ -5,40 +5,45 @@ class ProviderMock(AbstractProvider): working = True + @classmethod def create_completion( - model, messages, stream, **kwargs + cls, model, messages, stream, **kwargs ): yield "Mock" class AsyncProviderMock(AsyncProvider): working = True + @classmethod async def create_async( - model, messages, **kwargs + cls, model, messages, **kwargs ): return "Mock" class AsyncGeneratorProviderMock(AsyncGeneratorProvider): working = True + @classmethod async def create_async_generator( - model, messages, stream, **kwargs + cls, model, messages, stream, **kwargs ): yield "Mock" class ModelProviderMock(AbstractProvider): working = True + @classmethod def create_completion( - model, messages, stream, **kwargs + cls, model, messages, stream, **kwargs ): yield model class YieldProviderMock(AsyncGeneratorProvider): working = True + @classmethod async def create_async_generator( - model, messages, stream, **kwargs + cls, model, messages, stream, **kwargs ): for message in messages: yield message["content"] @@ -84,8 +89,9 @@ async def create_async_generator( class YieldNoneProviderMock(AsyncGeneratorProvider): working = True - + + @classmethod async def create_async_generator( - model, messages, stream, **kwargs + cls, model, messages, stream, **kwargs ): yield None \ No newline at end of file diff --git a/etc/unittest/web_search.py b/etc/unittest/web_search.py new file mode 100644 index 00000000000..2af369fd4c0 --- /dev/null +++ b/etc/unittest/web_search.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import json +import unittest + +from g4f.client import AsyncClient +from g4f.web_search import DuckDuckGoSearchException, has_requirements +from .mocks import YieldProviderMock + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] + +class TestIterListProvider(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + if not has_requirements: + self.skipTest('web search requirements not passed') + + async def test_search(self): + client = AsyncClient(provider=YieldProviderMock) + tool_calls = [ + { + "function": { + "arguments": { + "query": "search query", # content of last message: messages[-1]["content"] + "max_results": 5, # maximum number of search results + "max_words": 500, # maximum number of used words from search results for generating the response + "backend": "html", # or "lite", "api": change it to pypass rate limits + "add_text": True, # do scraping websites + "timeout": 5, # in seconds for scraping websites + "region": "wt-wt", + "instructions": "Using the provided web search results, to write a comprehensive reply to the user request.\n" + "Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)", + }, + "name": "search_tool" + }, + "type": "function" + } + ] + try: + response = await client.chat.completions.create([{"content": "", "role": "user"}], "", tool_calls=tool_calls) + self.assertIn("Instruction: Using the provided web search results", response.choices[0].message.content) + except DuckDuckGoSearchException as e: + self.skipTest(f'DuckDuckGoSearchException: {e}') + + async def test_search2(self): + client = AsyncClient(provider=YieldProviderMock) + tool_calls = [ + { + "function": { + "arguments": { + "query": "search query", + }, + "name": "search_tool" + }, + "type": "function" + } + ] + try: + response = await client.chat.completions.create([{"content": "", "role": "user"}], "", tool_calls=tool_calls) + self.assertIn("Instruction: Using the provided web search results", response.choices[0].message.content) + except DuckDuckGoSearchException as e: + self.skipTest(f'DuckDuckGoSearchException: {e}') + + async def test_search3(self): + client = AsyncClient(provider=YieldProviderMock) + tool_calls = [ + { + "function": { + "arguments": json.dumps({ + "query": "search query", # content of last message: messages[-1]["content"] + "max_results": 5, # maximum number of search results + "max_words": 500, # maximum number of used words from search results for generating the response + }), + "name": "search_tool" + }, + "type": "function" + } + ] + try: + response = await client.chat.completions.create([{"content": "", "role": "user"}], "", tool_calls=tool_calls) + self.assertIn("Instruction: Using the provided web search results", response.choices[0].message.content) + except DuckDuckGoSearchException as e: + self.skipTest(f'DuckDuckGoSearchException: {e}') \ No newline at end of file diff --git a/g4f/Provider/Mhystical.py b/g4f/Provider/Mhystical.py index 14412c07c5b..380da18d3c8 100644 --- a/g4f/Provider/Mhystical.py +++ b/g4f/Provider/Mhystical.py @@ -1,12 +1,7 @@ from __future__ import annotations -import json -import logging -from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from ..requests.raise_for_status import raise_for_status -from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from .helper import format_prompt +from .needs_auth.OpenaiAPI import OpenaiAPI """ Mhystical.cc @@ -19,39 +14,31 @@ """ -logger = logging.getLogger(__name__) - -class Mhystical(AsyncGeneratorProvider, ProviderModelMixin): +class Mhystical(OpenaiAPI): url = "https://api.mhystical.cc" api_endpoint = "https://api.mhystical.cc/v1/completions" working = True + needs_auth = False supports_stream = False # Set to False, as streaming is not specified in ChatifyAI supports_system_message = False - supports_message_history = True default_model = 'gpt-4' models = [default_model] - model_aliases = {} @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases.get(model, cls.default_model) - else: - return cls.default_model + def get_model(cls, model: str, **kwargs) -> str: + cls.last_model = cls.default_model + return cls.default_model @classmethod - async def create_async_generator( + def create_async_generator( cls, model: str, messages: Messages, - proxy: str = None, + stream: bool = False, **kwargs ) -> AsyncResult: model = cls.get_model(model) - headers = { "x-api-key": "mhystical", "Content-Type": "application/json", @@ -61,24 +48,11 @@ async def create_async_generator( "referer": f"{cls.url}/", "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36" } - - async with ClientSession(headers=headers) as session: - data = { - "model": model, - "messages": [{"role": "user", "content": format_prompt(messages)}] - } - async with session.post(cls.api_endpoint, json=data, headers=headers, proxy=proxy) as response: - await raise_for_status(response) - response_text = await response.text() - filtered_response = cls.filter_response(response_text) - yield filtered_response - - @staticmethod - def filter_response(response_text: str) -> str: - try: - json_response = json.loads(response_text) - message_content = json_response["choices"][0]["message"]["content"] - return message_content - except (KeyError, IndexError, json.JSONDecodeError) as e: - logger.error("Error parsing response: %s", e) - return "Error: Failed to parse response from API." + return super().create_async_generator( + model=model, + messages=messages, + stream=cls.supports_stream, + api_endpoint=cls.api_endpoint, + headers=headers, + **kwargs + ) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py index ec5f491f94b..6471895e6e5 100644 --- a/g4f/Provider/needs_auth/OpenaiAPI.py +++ b/g4f/Provider/needs_auth/OpenaiAPI.py @@ -7,7 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ...typing import Union, Optional, AsyncResult, Messages, ImagesType from ...requests import StreamSession, raise_for_status -from ...providers.response import FinishReason +from ...providers.response import FinishReason, ToolCalls, Usage from ...errors import MissingAuthError, ResponseError from ...image import to_data_uri from ... import debug @@ -51,6 +51,7 @@ async def create_async_generator( timeout: int = 120, images: ImagesType = None, api_key: str = None, + api_endpoint: str = None, api_base: str = None, temperature: float = None, max_tokens: int = None, @@ -59,6 +60,7 @@ async def create_async_generator( stream: bool = False, headers: dict = None, impersonate: str = None, + tools: Optional[list] = None, extra_data: dict = {}, **kwargs ) -> AsyncResult: @@ -93,16 +95,23 @@ async def create_async_generator( top_p=top_p, stop=stop, stream=stream, + tools=tools, **extra_data ) - async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response: + if api_endpoint is None: + api_endpoint = f"{api_base.rstrip('/')}/chat/completions" + async with session.post(api_endpoint, json=data) as response: await raise_for_status(response) if not stream: data = await response.json() cls.raise_error(data) choice = data["choices"][0] - if "content" in choice["message"]: + if "content" in choice["message"] and choice["message"]["content"]: yield choice["message"]["content"].strip() + elif "tool_calls" in choice["message"]: + yield ToolCalls(choice["message"]["tool_calls"]) + if "usage" in data: + yield Usage(**data["usage"]) finish = cls.read_finish_reason(choice) if finish is not None: yield finish diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 093cc6a3392..6f3843da3cf 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -30,11 +30,6 @@ from starlette._compat import md5_hexdigest from types import SimpleNamespace from typing import Union, Optional, List -try: - from typing import Annotated -except ImportError: - class Annotated: - pass import g4f import g4f.debug @@ -50,7 +45,7 @@ class Annotated: ChatCompletionsConfig, ImageGenerationConfig, ProviderResponseModel, ModelResponseModel, ErrorResponseModel, ProviderResponseDetailModel, - FileResponseModel + FileResponseModel, Annotated ) logger = logging.getLogger(__name__) diff --git a/g4f/api/stubs.py b/g4f/api/stubs.py index 8610f6b0e32..bd48e0bf950 100644 --- a/g4f/api/stubs.py +++ b/g4f/api/stubs.py @@ -2,7 +2,11 @@ from pydantic import BaseModel, Field from typing import Union, Optional - +try: + from typing import Annotated +except ImportError: + class Annotated: + pass from g4f.typing import Messages class ChatCompletionsConfig(BaseModel): @@ -23,6 +27,16 @@ class ChatCompletionsConfig(BaseModel): history_disabled: Optional[bool] = None auto_continue: Optional[bool] = None timeout: Optional[int] = None + tool_calls: list = Field(default=[], examples=[[ + { + "function": { + "arguments": {"query":"search query", "max_results":5, "max_words": 2500, "backend": "api", "add_text": True, "timeout": 5}, + "name": "search_tool" + }, + "type": "function" + } + ]]) + tools: list = None class ImageGenerationConfig(BaseModel): prompt: str diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index e717f41c79e..b48f8036e03 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -6,20 +6,22 @@ import string import asyncio import base64 +import json from typing import Union, AsyncIterator, Iterator, Coroutine, Optional from ..image import ImageResponse, copy_images, images_dir from ..typing import Messages, ImageType -from ..providers.types import ProviderType -from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData +from ..providers.types import ProviderType, BaseRetryProvider +from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage from ..errors import NoImageResponseError from ..providers.retry_provider import IterListProvider from ..providers.asyncio import to_sync_generator, async_generator_to_list +from ..web_search import get_search_message, do_search from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .image_models import ImageModels from .types import IterResponse, ImageProvider, Client as BaseClient -from .service import get_model_and_provider, get_last_provider, convert_to_provider +from .service import get_model_and_provider, convert_to_provider from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator from .. import debug @@ -35,6 +37,47 @@ async def anext(aiter): except StopAsyncIteration: raise StopIteration +def validate_arguments(data: dict): + if "arguments" in data: + if isinstance(data["arguments"], str): + data["arguments"] = json.loads(data["arguments"]) + if not isinstance(data["arguments"], dict): + raise ValueError("Tool function arguments must be a dictionary or a json string") + else: + return filter_none(**data["arguments"]) + else: + return {} + +async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs): + if tool_calls is not None: + for tool in tool_calls: + if tool.get("type") == "function": + if tool.get("function", {}).get("name") == "search_tool": + tool["function"]["arguments"] = validate_arguments(tool["function"]) + messages = messages.copy() + messages[-1]["content"] = await do_search( + messages[-1]["content"], + **tool["function"]["arguments"] + ) + response = async_iter_callback(model=model, messages=messages, **kwargs) + if not hasattr(response, "__aiter__"): + response = to_async_iterator(response) + async for chunk in response: + yield chunk + +def iter_run_tools(iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs): + if tool_calls is not None: + for tool in tool_calls: + if tool.get("type") == "function": + if tool.get("function", {}).get("name") == "search_tool": + tool["function"]["arguments"] = validate_arguments(tool["function"]) + messages[-1]["content"] = get_search_message( + messages[-1]["content"], + raise_search_exceptions=True, + **tool["function"]["arguments"] + ) + return iter_callback(model=model, messages=messages, **kwargs) + # Synchronous iter_response function def iter_response( response: Union[Iterator[Union[str, ResponseType]]], @@ -45,6 +88,8 @@ def iter_response( ) -> ChatCompletionResponseType: content = "" finish_reason = None + tool_calls = None + usage = None completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) idx = 0 @@ -55,6 +100,12 @@ def iter_response( if isinstance(chunk, FinishReason): finish_reason = chunk.reason break + elif isinstance(chunk, ToolCalls): + tool_calls = chunk.get_list() + continue + elif isinstance(chunk, Usage): + usage = chunk.get_dict() + continue elif isinstance(chunk, BaseConversation): yield chunk continue @@ -88,18 +139,21 @@ def iter_response( if response_format is not None and "type" in response_format: if response_format["type"] == "json_object": content = filter_json(content) - yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time())) + yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()), **filter_none( + tool_calls=tool_calls, + usage=usage + )) # Synchronous iter_append_model_and_provider function -def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType: - last_provider = None - +def iter_append_model_and_provider(response: ChatCompletionResponseType, last_model: str, last_provider: ProviderType) -> ChatCompletionResponseType: + if isinstance(last_provider, BaseRetryProvider): + last_provider = last_provider.last_provider for chunk in response: if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)): - last_provider = get_last_provider(True) if last_provider is None else last_provider - chunk.model = last_provider.get("model") - chunk.provider = last_provider.get("name") - yield chunk + if last_provider is not None: + chunk.model = getattr(last_provider, "last_model", last_model) + chunk.provider = last_provider.__name__ + yield chunk async def async_iter_response( response: AsyncIterator[Union[str, ResponseType]], @@ -155,15 +209,20 @@ async def async_iter_response( await safe_aclose(response) async def async_iter_append_model_and_provider( - response: AsyncChatCompletionResponseType + response: AsyncChatCompletionResponseType, + last_model: str, + last_provider: ProviderType ) -> AsyncChatCompletionResponseType: last_provider = None try: + if isinstance(last_provider, BaseRetryProvider): + if last_provider is not None: + last_provider = last_provider.last_provider async for chunk in response: if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)): - last_provider = get_last_provider(True) if last_provider is None else last_provider - chunk.model = last_provider.get("model") - chunk.provider = last_provider.get("name") + if last_provider is not None: + chunk.model = getattr(last_provider, "last_model", last_model) + chunk.provider = last_provider.__name__ yield chunk finally: await safe_aclose(response) @@ -215,7 +274,9 @@ def create( kwargs["images"] = [(image, image_name)] if ignore_stream: kwargs["ignore_stream"] = True - response = provider.create_completion( + + response = iter_run_tools( + provider.create_completion, model, messages, stream=stream, @@ -237,7 +298,7 @@ def create( # If response is an async generator, collect it into a list response = asyncio.run(async_generator_to_list(response)) response = iter_response(response, stream, response_format, max_tokens, stop) - response = iter_append_model_and_provider(response) + response = iter_append_model_and_provider(response, model, provider) if stream: return response else: @@ -314,10 +375,10 @@ async def async_generate( if isinstance(response, ImageResponse): return await self._process_image_response( response, - response_format, - proxy, model, - provider_name + provider_name, + response_format, + proxy ) if response is None: if error is not None: @@ -407,7 +468,7 @@ async def async_create_variation( response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) if isinstance(response, ImageResponse): - return await self._process_image_response(response, response_format, proxy, model, provider_name) + return await self._process_image_response(response, model, provider_name, response_format, proxy) if response is None: if error is not None: raise error @@ -417,12 +478,11 @@ async def async_create_variation( async def _process_image_response( self, response: ImageResponse, + model: str, + provider: str, response_format: Optional[str] = None, - proxy: str = None, - model: Optional[str] = None, - provider: Optional[str] = None + proxy: str = None ) -> ImagesResponse: - last_provider = get_last_provider(True) if response_format == "url": # Return original URLs without saving locally images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()] @@ -440,8 +500,8 @@ async def process_image_item(image_file: str) -> Image: return ImagesResponse.model_construct( created=int(time.time()), data=images, - model=last_provider.get("model") if model is None else model, - provider=last_provider.get("name") if provider is None else provider + model=model, + provider=provider ) @@ -502,7 +562,8 @@ def create( create_handler = provider.create_async_generator else: create_handler = provider.create_completion - response = create_handler( + response = async_iter_run_tools( + create_handler, model, messages, stream=stream, @@ -514,11 +575,8 @@ def create( ), **kwargs ) - - if not hasattr(response, "__aiter__"): - response = to_async_iterator(response) response = async_iter_response(response, stream, response_format, max_tokens, stop) - response = async_iter_append_model_and_provider(response) + response = async_iter_append_model_and_provider(response, model, provider) return response if stream else anext(response) class AsyncImages(Images): diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py index 575327690ed..8f3425de9b4 100644 --- a/g4f/client/stubs.py +++ b/g4f/client/stubs.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Any from time import time from .helper import filter_none +ToolCalls = Optional[List[Dict[str, Any]]] +Usage = Optional[Dict[str, int]] + try: from pydantic import BaseModel, Field except ImportError: @@ -57,10 +60,11 @@ def model_construct( class ChatCompletionMessage(BaseModel): role: str content: str + tool_calls: ToolCalls @classmethod - def model_construct(cls, content: str): - return super().model_construct(role="assistant", content=content) + def model_construct(cls, content: str, tool_calls: ToolCalls = None): + return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls)) class ChatCompletionChoice(BaseModel): index: int @@ -78,11 +82,11 @@ class ChatCompletion(BaseModel): model: str provider: Optional[str] choices: List[ChatCompletionChoice] - usage: Dict[str, int] = Field(examples=[{ + usage: Usage = Field(default={ "prompt_tokens": 0, #prompt_tokens, "completion_tokens": 0, #completion_tokens, "total_tokens": 0, #prompt_tokens + completion_tokens, - }]) + }) @classmethod def model_construct( @@ -90,7 +94,9 @@ def model_construct( content: str, finish_reason: str, completion_id: str = None, - created: int = None + created: int = None, + tool_calls: ToolCalls = None, + usage: Usage = None ): return super().model_construct( id=f"chatcmpl-{completion_id}" if completion_id else None, @@ -99,14 +105,10 @@ def model_construct( model=None, provider=None, choices=[ChatCompletionChoice.model_construct( - ChatCompletionMessage.model_construct(content), - finish_reason + ChatCompletionMessage.model_construct(content, tool_calls), + finish_reason, )], - usage={ - "prompt_tokens": 0, #prompt_tokens, - "completion_tokens": 0, #completion_tokens, - "total_tokens": 0, #prompt_tokens + completion_tokens, - } + **filter_none(usage=usage) ) class ChatCompletionDelta(BaseModel): diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index dcbf33228e7..d9a886af2df 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -97,7 +97,7 @@ def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict): kwargs['web_search'] = True do_web_search = False if do_web_search: - from .internet import get_search_message + from ...web_search import get_search_message messages[-1]["content"] = get_search_message(messages[-1]["content"]) if json_data.get("auto_continue"): kwargs['auto_continue'] = True diff --git a/g4f/gui/server/internet.py b/g4f/gui/server/internet.py index 96496f6552f..47a8556bfdd 100644 --- a/g4f/gui/server/internet.py +++ b/g4f/gui/server/internet.py @@ -1,160 +1,3 @@ from __future__ import annotations -from aiohttp import ClientSession, ClientTimeout -try: - from duckduckgo_search import DDGS - from bs4 import BeautifulSoup - has_requirements = True -except ImportError: - has_requirements = False -from ...errors import MissingRequirementsError -from ... import debug - -import asyncio - -class SearchResults(): - def __init__(self, results: list, used_words: int): - self.results = results - self.used_words = used_words - - def __iter__(self): - yield from self.results - - def __str__(self): - search = "" - for idx, result in enumerate(self.results): - if search: - search += "\n\n\n" - search += f"Title: {result.title}\n\n" - if result.text: - search += result.text - else: - search += result.snippet - search += f"\n\nSource: [[{idx}]]({result.url})" - return search - - def __len__(self) -> int: - return len(self.results) - -class SearchResultEntry(): - def __init__(self, title: str, url: str, snippet: str, text: str = None): - self.title = title - self.url = url - self.snippet = snippet - self.text = text - - def set_text(self, text: str): - self.text = text - -def scrape_text(html: str, max_words: int = None) -> str: - soup = BeautifulSoup(html, "html.parser") - for selector in [ - "main", - ".main-content-wrapper", - ".main-content", - ".emt-container-inner", - ".content-wrapper", - "#content", - "#mainContent", - ]: - select = soup.select_one(selector) - if select: - soup = select - break - # Zdnet - for remove in [".c-globalDisclosure"]: - select = soup.select_one(remove) - if select: - select.extract() - clean_text = "" - for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"): - text = paragraph.get_text() - for line in text.splitlines(): - words = [] - for word in line.replace("\t", " ").split(" "): - if word: - words.append(word) - count = len(words) - if not count: - continue - if max_words: - max_words -= count - if max_words <= 0: - break - if clean_text: - clean_text += "\n" - clean_text += " ".join(words) - - return clean_text - -async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str: - try: - async with session.get(url) as response: - if response.status == 200: - html = await response.text() - return scrape_text(html, max_words) - except: - return - -async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text: bool = True) -> SearchResults: - if not has_requirements: - raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]') - with DDGS() as ddgs: - results = [] - for result in ddgs.text( - query, - region="wt-wt", - safesearch="moderate", - timelimit="y", - max_results=n_results, - ): - results.append(SearchResultEntry( - result["title"], - result["href"], - result["body"] - )) - - if add_text: - requests = [] - async with ClientSession(timeout=ClientTimeout(5)) as session: - for entry in results: - requests.append(fetch_and_scrape(session, entry.url, int(max_words / (n_results - 1)))) - texts = await asyncio.gather(*requests) - - formatted_results = [] - used_words = 0 - left_words = max_words - for i, entry in enumerate(results): - if add_text: - entry.text = texts[i] - if left_words: - left_words -= entry.title.count(" ") + 5 - if entry.text: - left_words -= entry.text.count(" ") - else: - left_words -= entry.snippet.count(" ") - if 0 > left_words: - break - used_words = max_words - left_words - formatted_results.append(entry) - - return SearchResults(formatted_results, used_words) - -def get_search_message(prompt, n_results: int = 5, max_words: int = 2500) -> str: - try: - search_results = asyncio.run(search(prompt, n_results, max_words)) - message = f""" -{search_results} - - -Instruction: Using the provided web search results, to write a comprehensive reply to the user request. -Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com) - -User request: -{prompt} -""" - debug.log(f"Web search: '{prompt.strip()[:50]}...' {search_results.used_words} Words") - return message - except Exception as e: - debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") - return prompt \ No newline at end of file +from ...web_search import SearchResults, search, get_search_message \ No newline at end of file diff --git a/g4f/image.py b/g4f/image.py index 95d410fed84..8fae5aedbdc 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -222,12 +222,11 @@ def to_bytes(image: ImageType) -> bytes: elif isinstance(image, Path): return image.read_bytes() else: - fp = open(image, "rb") try: - fp.seek(0) + image.seek(0) except (AttributeError, io.UnsupportedOperation): pass - return fp.read() + return image.read() def to_data_uri(image: ImageType) -> str: if not isinstance(image, str): diff --git a/g4f/providers/response.py b/g4f/providers/response.py index 44020bedf6f..4224436fec9 100644 --- a/g4f/providers/response.py +++ b/g4f/providers/response.py @@ -71,6 +71,18 @@ class ResponseType: def __str__(self) -> str: pass +class JsonMixin: + def __init__(self, **kwargs) -> None: + for key, value in kwargs.items(): + setattr(self, key, value) + + def get_dict(self): + return { + key: value + for key, value in self.__dict__.items() + if not key.startswith("__") + } + class FinishReason(): def __init__(self, reason: str): self.reason = reason @@ -78,6 +90,20 @@ def __init__(self, reason: str): def __str__(self) -> str: return "" +class ToolCalls(ResponseType): + def __init__(self, list: list): + self.list = list + + def __str__(self) -> str: + return "" + + def get_list(self) -> list: + return self.list + +class Usage(ResponseType, JsonMixin): + def __str__(self) -> str: + return "" + class TitleGeneration(ResponseType): def __init__(self, title: str) -> None: self.title = title @@ -108,28 +134,14 @@ class BaseConversation(ResponseType): def __str__(self) -> str: return "" -class JsonConversation(BaseConversation): - def __init__(self, **kwargs) -> None: - for key, value in kwargs.items(): - setattr(self, key, value) +class JsonConversation(BaseConversation, JsonMixin): + pass - def to_dict(self): - return { - key: value - for key, value in self.__dict__.items() - if not key.startswith("__") - } - -class SynthesizeData(ResponseType): +class SynthesizeData(ResponseType, JsonMixin): def __init__(self, provider: str, data: dict): self.provider = provider self.data = data - def to_json(self) -> dict: - return { - **self.__dict__ - } - def __str__(self) -> str: return "" diff --git a/g4f/web_search.py b/g4f/web_search.py new file mode 100644 index 00000000000..652555b639e --- /dev/null +++ b/g4f/web_search.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from aiohttp import ClientSession, ClientTimeout, ClientError +try: + from duckduckgo_search import DDGS + from duckduckgo_search.exceptions import DuckDuckGoSearchException + from bs4 import BeautifulSoup + has_requirements = True +except ImportError: + has_requirements = False +from .errors import MissingRequirementsError +from . import debug + +import asyncio + +DEFAULT_INSTRUCTIONS = """ +Using the provided web search results, to write a comprehensive reply to the user request. +Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com) +""" + +class SearchResults(): + def __init__(self, results: list, used_words: int): + self.results = results + self.used_words = used_words + + def __iter__(self): + yield from self.results + + def __str__(self): + search = "" + for idx, result in enumerate(self.results): + if search: + search += "\n\n\n" + search += f"Title: {result.title}\n\n" + if result.text: + search += result.text + else: + search += result.snippet + search += f"\n\nSource: [[{idx}]]({result.url})" + return search + + def __len__(self) -> int: + return len(self.results) + +class SearchResultEntry(): + def __init__(self, title: str, url: str, snippet: str, text: str = None): + self.title = title + self.url = url + self.snippet = snippet + self.text = text + + def set_text(self, text: str): + self.text = text + +def scrape_text(html: str, max_words: int = None) -> str: + soup = BeautifulSoup(html, "html.parser") + for selector in [ + "main", + ".main-content-wrapper", + ".main-content", + ".emt-container-inner", + ".content-wrapper", + "#content", + "#mainContent", + ]: + select = soup.select_one(selector) + if select: + soup = select + break + # Zdnet + for remove in [".c-globalDisclosure"]: + select = soup.select_one(remove) + if select: + select.extract() + clean_text = "" + for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"): + text = paragraph.get_text() + for line in text.splitlines(): + words = [] + for word in line.replace("\t", " ").split(" "): + if word: + words.append(word) + count = len(words) + if not count: + continue + if max_words: + max_words -= count + if max_words <= 0: + break + if clean_text: + clean_text += "\n" + clean_text += " ".join(words) + + return clean_text + +async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str: + try: + async with session.get(url) as response: + if response.status == 200: + html = await response.text() + return scrape_text(html, max_words) + except ClientError: + return + +async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "api", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults: + if not has_requirements: + raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]') + with DDGS() as ddgs: + results = [] + for result in ddgs.text( + query, + region=region, + safesearch="moderate", + timelimit="y", + max_results=max_results, + backend=backend, + ): + results.append(SearchResultEntry( + result["title"], + result["href"], + result["body"] + )) + + if add_text: + requests = [] + async with ClientSession(timeout=ClientTimeout(timeout)) as session: + for entry in results: + requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)))) + texts = await asyncio.gather(*requests) + + formatted_results = [] + used_words = 0 + left_words = max_words + for i, entry in enumerate(results): + if add_text: + entry.text = texts[i] + if left_words: + left_words -= entry.title.count(" ") + 5 + if entry.text: + left_words -= entry.text.count(" ") + else: + left_words -= entry.snippet.count(" ") + if 0 > left_words: + break + used_words = max_words - left_words + formatted_results.append(entry) + + return SearchResults(formatted_results, used_words) + +async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str: + if query is None: + query = prompt + search_results = await search(query, **kwargs) + new_prompt = f""" +{search_results} + +Instruction: {instructions} + +User request: +{prompt} +""" + debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words") + return new_prompt + +def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str: + try: + return asyncio.run(do_search(prompt, **kwargs)) + except (DuckDuckGoSearchException, MissingRequirementsError) as e: + if raise_search_exceptions: + raise e + debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") + return prompt \ No newline at end of file