diff --git a/examples/assistant_mcp_sqlite_bot.py b/examples/assistant_mcp_sqlite_bot.py new file mode 100644 index 0000000..6a402ab --- /dev/null +++ b/examples/assistant_mcp_sqlite_bot.py @@ -0,0 +1,99 @@ +"""A sqlite database assistant implemented by assistant""" + +import os +import asyncio +from typing import Optional + +from qwen_agent.agents import Assistant +from qwen_agent.gui import WebUI + +ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), 'resource') + + +def init_agent_service(): + llm_cfg = {'model': 'qwen-max'} + system = ('你扮演一个数据库助手,你具有查询数据库的能力') + tools = [{ + "mcpServers": { + "sqlite" : { + "command": "uvx", + "args": [ + "mcp-server-sqlite", + "--db-path", + "test.db" + ] + } + } + }] + bot = Assistant( + llm=llm_cfg, + name='数据库助手', + description='数据库查询', + system_message=system, + function_list=tools, + ) + + return bot + + +def test(query='数据库里有几张表', file: Optional[str] = os.path.join(ROOT_RESOURCE, 'poem.pdf')): + # Define the agent + bot = init_agent_service() + + # Chat + messages = [] + + if not file: + messages.append({'role': 'user', 'content': query}) + else: + messages.append({'role': 'user', 'content': [{'text': query}, {'file': file}]}) + + for response in bot.run(messages): + print('bot response:', response) + + +def app_tui(): + # Define the agent + bot = init_agent_service() + + # Chat + messages = [] + while True: + # Query example: 数据库里有几张表 + query = input('user question: ') + # File example: resource/poem.pdf + file = input('file url (press enter if no file): ').strip() + if not query: + print('user question cannot be empty!') + continue + if not file: + messages.append({'role': 'user', 'content': query}) + else: + messages.append({'role': 'user', 'content': [{'text': query}, {'file': file}]}) + + response = [] + for response in bot.run(messages): + print('bot response:', response) + messages.extend(response) + + +def app_gui(): + # Define the agent + bot = init_agent_service() + chatbot_config = { + 'prompt.suggestions': [ + '数据库里有几张表', + '创建一个学生表包括学生的姓名、年龄', + '增加一个学生名字叫韩梅梅,今年6岁', + ] + } + WebUI( + bot, + chatbot_config=chatbot_config, + ).run() + + +if __name__ == '__main__': + # test() + # app_tui() + app_gui() diff --git a/qwen_agent/agent.py b/qwen_agent/agent.py index de201b2..b4b95cd 100644 --- a/qwen_agent/agent.py +++ b/qwen_agent/agent.py @@ -8,7 +8,7 @@ from qwen_agent.llm.base import BaseChatModel from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM, ContentItem, Message from qwen_agent.log import logger -from qwen_agent.tools import TOOL_REGISTRY, BaseTool +from qwen_agent.tools import TOOL_REGISTRY, BaseTool, MCPManager from qwen_agent.utils.utils import has_chinese_messages, merge_generate_cfgs @@ -192,6 +192,13 @@ def _init_tool(self, tool: Union[str, Dict, BaseTool]): if tool_name in self.function_map: logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list') self.function_map[tool_name] = tool + elif isinstance(tool, dict) and 'mcpServers' in tool: + tools = MCPManager().initConfig(tool) + for tool in tools: + tool_name = tool.name + if tool_name in self.function_map: + logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list') + self.function_map[tool_name] = tool else: if isinstance(tool, dict): tool_name = tool['name'] diff --git a/qwen_agent/tools/__init__.py b/qwen_agent/tools/__init__.py index 32ed4c1..589d264 100644 --- a/qwen_agent/tools/__init__.py +++ b/qwen_agent/tools/__init__.py @@ -10,6 +10,7 @@ from .simple_doc_parser import SimpleDocParser from .storage import Storage from .web_extractor import WebExtractor +from .mcp_manager import MCPManager __all__ = [ 'BaseTool', @@ -28,4 +29,5 @@ 'FrontPageSearch', 'ExtractDocVocabulary', 'PythonExecutor', + 'MCPManager' ] diff --git a/qwen_agent/tools/mcp_manager.py b/qwen_agent/tools/mcp_manager.py new file mode 100644 index 0000000..2a2eb71 --- /dev/null +++ b/qwen_agent/tools/mcp_manager.py @@ -0,0 +1,187 @@ +import json +import urllib.parse +import asyncio +import threading +from typing import Optional, Union, List, Dict +from contextlib import AsyncExitStack + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +from qwen_agent.log import logger +from qwen_agent.tools.base import BaseTool, register_tool + +from dotenv import load_dotenv + +class MCPManager: + _instance = None # Private class variable to store the unique instance + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(MCPManager, cls).__new__(cls, *args, **kwargs) + cls._instance.__init__() + return cls._instance + + def __init__(self): + if not hasattr(self, 'clients'): + """Set a new event loop in a separate thread""" + load_dotenv() # Load environment variables from .env file + self.clients: dict = {} + self.exit_stack = AsyncExitStack() + self.loop = asyncio.new_event_loop() + self.loop_thread = threading.Thread(target=self.start_loop, daemon=True) + self.loop_thread.start() + + def start_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def is_valid_mcp_servers(self, config: dict): + """Example of mcp servers configuration: + { + "mcpServers": { + "memory": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-memory"] + }, + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"] + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "" + } + } + } + } + """ + + # Check if the top-level key "mcpServers" exists and its value is a dictionary + if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict): + return False + mcp_servers = config['mcpServers'] + # Check each sub-item under "mcpServers" + for key in mcp_servers: + server = mcp_servers[key] + # Each sub-item must be a dictionary and contain the keys "command" and "args" + if not isinstance(server, dict) or 'command' not in server or 'args' not in server: + return False + # "command" must be a string + if not isinstance(server['command'], str): + return False + # "args" must be a list + if not isinstance(server['args'], list): + return False + # If the "env" key exists, it must be a dictionary + if 'env' in server and not isinstance(server['env'], dict): + return False + return True + + def initConfig(self, config: Dict): + logger.info(f'Initialize from config {config}. ') + if not self.is_valid_mcp_servers(config): + raise ValueError('Config format error') + # Submit coroutine to the event loop and wait for the result + future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop) + try: + result = future.result() # You can specify a timeout if desired + return result + except Exception as e: + logger.info(f"Error executing function: {e}") + return None + + async def init_config_async(self, config: Dict): + tools : list = [] + mcp_servers = config['mcpServers'] + for server_name in mcp_servers: + client = MCPClient() + server = mcp_servers[server_name] + await client.connection_server(self.exit_stack, server) # Attempt to connect to the server + self.clients[server_name] = client # Add to clients dict after successful connection + for tool in client.tools: + """MCP tool example: + { + "name": "read_query", + "description": "Execute a SELECT query on the SQLite database", + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "SELECT SQL query to execute" + } + }, + "required": ["query"] + } + """ + parameters = tool.inputSchema + # The required field in inputSchema may be empty and needs to be initialized. + if 'required' not in parameters: + parameters['required'] = [] + register_name = server_name + "-" + tool.name + agent_tool = self.create_tool_class(register_name, server_name, tool.name, tool.description, parameters) + tools.append(agent_tool) + return tools + + def create_tool_class(self, register_name, server_name, tool_name, tool_desc, tool_parameters): + @register_tool(register_name) + class ToolClass(BaseTool): + description = tool_desc + parameters = tool_parameters + + def call(self, params: Union[str, dict], **kwargs) -> str: + tool_args = json.loads(params) + # Submit coroutine to the event loop and wait for the result + manager = MCPManager() + client = manager.clients[server_name] + future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop) + try: + result = future.result() + return result + except Exception as e: + logger.info(f"Error executing function: {e}") + return None + return "Function executed" + + ToolClass.__name__ = f"{register_name}_Class" + return ToolClass() + + async def clearup(self): + await self.exit_stack.aclose() + + +class MCPClient: + def __init__(self): + # Initialize session and client objects + self.session: Optional[ClientSession] = None + self.tools : list = None + + async def connection_server(self, exit_stack, mcp_server): + """Connect to an MCP server and retrieve the available tools.""" + try: + server_params = StdioServerParameters( + command = mcp_server["command"], + args = mcp_server["args"], + env = mcp_server.get("env", None) + ) + stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params)) + self.stdio, self.write = stdio_transport + self.session = await exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + + await self.session.initialize() + + list_tools = await self.session.list_tools() + self.tools = list_tools.tools + except Exception as e: + logger.info(f"Failed to connect to server: {e}") + + async def execute_function(self, tool_name, tool_args: dict): + response = await self.session.call_tool(tool_name, tool_args) + for content in response.content: + if content.type == 'text': + return content.text + else: + return "execute error" diff --git a/setup.py b/setup.py index 5dae7fe..dc5960c 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ def read_description() -> str: 'pydantic>=2.3.0', 'requests', 'tiktoken', + 'mcp', ], extras_require={ # Extra dependencies for RAG: