From d7031209ee81c7df7883508f4f6d20e4498d3586 Mon Sep 17 00:00:00 2001 From: sanbuphy Date: Wed, 11 Dec 2024 22:52:57 +0800 Subject: [PATCH] add web search --- .env.example | 1 + .gitignore | 1 + README.md | 5 +- run/demo_agent_metagpt.py | 4 +- test/agents/metagpt/test_WebSearch.py | 13 +++++ .../agents/metagpt_agents/answerBot/action.py | 9 ++- .../agents/metagpt_agents/searcher/action.py | 56 ++++++++++++++----- .../agents/metagpt_agents/utils/agent_llm.py | 2 +- 8 files changed, 68 insertions(+), 23 deletions(-) create mode 100644 test/agents/metagpt/test_WebSearch.py diff --git a/.env.example b/.env.example index f6a4b65..d722665 100644 --- a/.env.example +++ b/.env.example @@ -6,3 +6,4 @@ OPENAI_API_BASE="https://api.siliconflow.cn/v1" OPENAI_API_MODEL='Qwen/Qwen2.5-7B-Instruct' ZHIPUAI_API_KEY= HF_TOKEN= +TAVILY_API_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore index 09ebd79..e6b1cb0 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ logs/ build/* tianji.egg-info/* temp/* +temp_datasets/* *.bin *.pyc test/knowledges/langchain/cache/ diff --git a/README.md b/README.md index c05be15..6dbee82 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,7 @@ pip install -e . 
为确保项目正常运行,**请在项目内新建`.env`文件,并在其中设置你的API密钥**,你可以根据下列例子写入对应的 key,即可成功运行调用,目前默认使用 [siliconflow](https://cloud.siliconflow.cn/models) 与 [ZhipuAI](https://bigmodel.cn/),你可以获取对应token即可使用。 -当前 prompt demo 使用 ZhipuAI api,rag 与 agent demo 使用 siliconflow api,可以根据实际需要进行切换使用。 +当前 Prompt demo 使用 ZhipuAI api,rag 与 agent demo 使用 Siliconflow api,你可以填写这两者调用密钥,即可使用 tianji 的全部功能。 ``` OPENAI_API_KEY= @@ -182,8 +182,11 @@ OPENAI_API_BASE= ZHIPUAI_API_KEY= OPENAI_API_MODEL= HF_TOKEN= +TAVILY_API_KEY= ``` +如果你想要结合 Agent 中的网络搜索工具给出更好的回答,你需要填写上述环境变量的 TAVILY_API_KEY 进行搜索请求,你可以在 [TAVILY 官网](https://app.tavily.com/home)获取体验免费密钥(个人免费额度) + ### 运行 以下给出 prompt 以及 agent 的相关应用方式,在运行前请确保你已经新建`.env`文件: diff --git a/run/demo_agent_metagpt.py b/run/demo_agent_metagpt.py index c8662d3..4b65576 100644 --- a/run/demo_agent_metagpt.py +++ b/run/demo_agent_metagpt.py @@ -61,7 +61,7 @@ def initialize_sidebar(scenes, sharedData): container_scene_attribute.write(st.session_state["scene_attr"]) st.button("Clear Chat History", on_click=lambda: on_btn_click(sharedData)) st.checkbox( - "启用网络搜索", value=st.session_state["enable_se"], key="check", on_change=flip + "启用网络搜索(确保填写密钥)", value=st.session_state["enable_se"], key="check", on_change=flip ) @@ -189,7 +189,7 @@ async def main(): # 如果开启已网络搜索助手 agent ,运行 agent if st.session_state["enable_se"] is True: - with st.spinner("SearcherAgent 运行中..."): + with st.spinner("启用搜索引擎,请稍等片刻... 
如有报错,请检查密钥是否填写正确"): await role_search.run(str(sharedData.message_list_for_agent)) sa_res1 = "生成的额外查询:" + str(sharedData.extra_query) diff --git a/test/agents/metagpt/test_WebSearch.py b/test/agents/metagpt/test_WebSearch.py new file mode 100644 index 0000000..5b7bee8 --- /dev/null +++ b/test/agents/metagpt/test_WebSearch.py @@ -0,0 +1,13 @@ +""" +Tavily AI is the leading search engine optimized for LLMs +https://app.tavily.com/ +""" +import os +from tavily import TavilyClient + +tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) + +response = tavily_client.search("What is the weather in Shanghai?",max_results=10) + +for url in response['results']: + print(url['url']) diff --git a/tianji/agents/metagpt_agents/answerBot/action.py b/tianji/agents/metagpt_agents/answerBot/action.py index 17659c4..0c7c108 100644 --- a/tianji/agents/metagpt_agents/answerBot/action.py +++ b/tianji/agents/metagpt_agents/answerBot/action.py @@ -4,14 +4,12 @@ from metagpt.actions import Action from tianji.agents.metagpt_agents.utils.json_from import SharedDataSingleton -from tianji.agents.metagpt_agents.utils.agent_llm import ZhipuApi as LLMApi +from tianji.agents.metagpt_agents.utils.agent_llm import OpenaiApi as LLMApi from tianji.agents.metagpt_agents.utils.helper_func import extract_single_type_attributes_and_examples, extract_attribute_descriptions, load_json - +from metagpt.logs import logger """ 回答助手 agent 所对应的 action。 """ - - class AnswerQuestion(Action): PROMPT_TEMPLATE: str = """ #Role: @@ -56,6 +54,7 @@ async def run(self, instruction: str): if "filtered_content" in item: filtered_dict[index] = item["filtered_content"] + logger.info("AnswerQuestion 最后的回复 agent :scene_attributes scene_attributes_description") prompt = self.PROMPT_TEMPLATE.format( scene=scene, scene_attributes=scene_attributes, @@ -66,5 +65,5 @@ async def run(self, instruction: str): else "", ) - rsp = await LLMApi()._aask(prompt=prompt, temperature=1.00) + rsp = await LLMApi()._aask(prompt=prompt, 
temperature=0.7) return rsp diff --git a/tianji/agents/metagpt_agents/searcher/action.py b/tianji/agents/metagpt_agents/searcher/action.py index e92e0c0..d9f9306 100644 --- a/tianji/agents/metagpt_agents/searcher/action.py +++ b/tianji/agents/metagpt_agents/searcher/action.py @@ -19,7 +19,7 @@ import requests from bs4 import BeautifulSoup import re - +from tavily import TavilyClient """ 网络搜索助手 agent 所对应的 action。 """ @@ -85,6 +85,7 @@ async def run(self, instruction: str): raise Exception("Searcher agent failed to response") ddgs = DDGS() +tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) class WebSearch(Action): name: str = "WebSearch" @@ -97,7 +98,7 @@ def search(query): max_retry = 5 for attempt in range(max_retry): try: - response = _call_ddgs(query) + response = _call_tavily(query) result = _parse_response(response) return result except Exception as e: @@ -105,8 +106,18 @@ def search(query): raise Exception( "Failed to get search results from DuckDuckGo after retries." 
) - + def _call_tavily(query: str, **kwargs) -> dict: + try: + logger.info(f"_call_tavily 正在搜索{query},kwargs为{kwargs}") + response = tavily_client.search(query, max_results=5) + return response + except Exception as e: + raise Exception(f"_call_tavily 搜索{query}出错: {str(e)}") + def _call_ddgs(query: str, **kwargs) -> dict: + """ + TODO ddgs 容易触发202限制,等到后续优化 + """ max_retry = 5 for attempt in range(max_retry): try: @@ -138,14 +149,30 @@ def _parse_response(response: dict) -> dict: raw_results = [] filtered_results = {} count = 0 - for item in response: - raw_results.append( - ( - item["href"], - item["description"] if "description" in item else item["body"], - item["title"], + + # 判断是否为 tavily 搜索引擎的结果 + if isinstance(response, dict) and 'results' in response: + # tavily 搜索引擎结果解析 + for item in response['results']: + raw_results.append( + ( + item['url'], + item['content'], + item['title'] + ) ) - ) + else: + # ddgs 搜索引擎结果解析 + for item in response: + raw_results.append( + ( + item["href"], + item["description"] if "description" in item else item["body"], + item["title"], + ) + ) + + # 过滤和格式化结果 for url, snippet, title in raw_results: if all( domain not in url @@ -160,7 +187,6 @@ def _parse_response(response: dict) -> dict: if count >= 20: # 确保最多返回20个网页的内容,可自行根据大模型的 context length 更换合适的参数。 break return filtered_results - logger.info(f"开始搜索{queries}") with ThreadPoolExecutor() as executor: future_to_query = {executor.submit(search, q): q for q in queries} @@ -226,7 +252,7 @@ async def run(self, instruction: str): for attempt in range(max_retry): try: rsp = await LLMApi()._aask(prompt=prompt, temperature=1.00) - logger.info("机器人分析需求:\n" + rsp) + logger.info("机器人 SelectResult 分析需求:\n" + rsp) rsp = ( rsp.replace("```list", "") .replace("```", "") @@ -281,12 +307,13 @@ def fetch(url: str) -> Tuple[bool, str]: else: if web_success: sharedData.search_results[select_id]["content"] = web_content[ - :4096 + :1024 ] return "" class FilterSelectedResult(Action): + # 该处最好用长上下文的模型 
PROMPT_TEMPLATE: str = """ #Role: - 数据抽取小助手。 @@ -316,7 +343,7 @@ async def ask(result, extra_query): search_results=result, extra_query=extra_query ) rsp = await LLMApi()._aask(prompt=prompt, temperature=1.00) - logger.info("机器人分析需求:\n" + rsp) + logger.info("机器人 FilterSelectedResult 分析需求:\n" + rsp) return rsp def run_ask(result, extra_query): @@ -342,6 +369,7 @@ def run_ask(result, extra_query): try: result = future.result() except Exception as exc: + logger.error(f"FilterSelectedResult 提取{select_id}出错: {str(exc)}") pass else: sharedData.search_results[select_id]["filtered_content"] = result diff --git a/tianji/agents/metagpt_agents/utils/agent_llm.py b/tianji/agents/metagpt_agents/utils/agent_llm.py index e5edf45..746e48e 100644 --- a/tianji/agents/metagpt_agents/utils/agent_llm.py +++ b/tianji/agents/metagpt_agents/utils/agent_llm.py @@ -35,7 +35,7 @@ async def _aask( response = self.client.chat.completions.create( model=model, messages=messages, - max_tokens=2048, + max_tokens=4096, top_p=top_p, temperature=temperature, stream=stream,