Skip to content

Commit

Permalink
修复 Langchain tool 的单测失败问题 (#663)
Browse files Browse the repository at this point in the history
* 修复 Langchain tool 的单测失败问题

* 更新自带 Tool 的定义

* 更新单测
  • Loading branch information
Dobiichi-Origami authored Jul 12, 2024
1 parent dc71eca commit c71aaec
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 24 deletions.
2 changes: 1 addition & 1 deletion python/qianfan/common/tool/baidu_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.client = client
self.with_reference = with_reference

def run(self, parameters: Dict[str, str] = {}) -> Union[str, Dict[str, Any]]:
def run(self, **parameters: Any) -> Union[str, Dict[str, Any]]:
"""
Run the tool and get the summary and reference of the search query
"""
Expand Down
11 changes: 6 additions & 5 deletions python/qianfan/common/tool/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class BaseTool:
describing the parameters needed when invoking the tool to the model.
"""

def run(self, parameters: Any = None) -> Any:
def run(self, **parameters: Any) -> Any:
"""
Runs the tool.
:param parameters: The input parameters for the tool
Expand All @@ -191,7 +191,7 @@ def to_function_call_schema(self) -> Dict[str, Any]:
@staticmethod
def from_langchain_tool(langchain_tool: Any) -> "BaseTool":
assert_package_installed("langchain")
from langchain.tools.base import BaseTool as LangchainBaseTool
from langchain_core.tools import BaseTool as LangchainBaseTool

if not isinstance(langchain_tool, LangchainBaseTool):
raise TypeError(
Expand All @@ -209,8 +209,8 @@ class Tool(BaseTool):
description = langchain_tool.description
parameters = root_properties

def run(self, parameters: Any = None) -> Any:
return langchain_tool._run(**parameters)
def run(self, **parameters: Any) -> Any:
return langchain_tool.run(**parameters)

return Tool()

Expand All @@ -229,6 +229,7 @@ class Tool(LangchainBaseTool):
args_schema: Type[PydanticV1BaseModel] = tool_schema

def _run(self, **kwargs: Any) -> Any:
return tool_run(kwargs)
print(kwargs)
return tool_run(**kwargs)

return Tool()
4 changes: 2 additions & 2 deletions python/qianfan/common/tool/duckduckgo_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
duck duck go search tool
"""

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from qianfan.common.tool.base_tool import BaseTool, ToolParameter
from qianfan.utils.utils import assert_package_installed
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
self.timelimit: Optional[str] = timelimit
self.max_results = max_results

def run(self, parameters: Dict[str, str] = {}) -> List[Dict[str, str]]:
def run(self, **parameters: Any) -> List[Dict[str, str]]:
from duckduckgo_search import DDGS

with DDGS() as client:
Expand Down
4 changes: 2 additions & 2 deletions python/qianfan/common/tool/wikipedia_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
wikipedia tool
"""

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from qianfan.common.tool.base_tool import BaseTool, ToolParameter
from qianfan.utils.utils import assert_package_installed
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
self.wiki_max_length = wiki_max_length
self.result_max_length = result_max_length

def run(self, parameters: Dict[str, str] = {}) -> List[Dict[str, str]]:
def run(self, **parameters: Any) -> List[Dict[str, str]]:
import wikipedia

query = parameters["search_keyword"][: self.WIKIPEDIA_MAX_QUERY_LENGTH]
Expand Down
28 changes: 14 additions & 14 deletions python/qianfan/tests/tool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ def _run(self, a: int, b: int, prefix: Optional[str] = None):
assert tool.parameters[2] == ToolParameter(
name="prefix", type="string", description="prefix description", required=False
)
assert tool.run({"a": 1, "b": 2}) == 3
assert tool.run({"a": 1, "b": 2, "prefix": "result: "}) == "result: 3"
assert tool.run(tool_input={"a": 1, "b": 2}) == 3
assert tool.run(tool_input={"a": 1, "b": 2, "prefix": "result: "}) == "result: 3"


def test_tool_from_langchain_func_tool():
if not check_package_installed("langchain"):
return

from langchain.tools.base import Tool as LangchainTool
from langchain_core.tools import StructuredTool as LangchainTool

from qianfan.utils.pydantic import BaseModel, Field

Expand Down Expand Up @@ -174,7 +174,7 @@ class FuncToolSchema(BaseModel):
assert tool.parameters[1] == ToolParameter(
name="b", type="string", description="b description", required=True
)
assert tool.run({"a": "1", "b": "2"}) == "hello 1 2"
assert tool.run(tool_input={"a": "1", "b": "2"}) == "hello 1 2"


def test_tool_from_langchain_decorator_tool():
Expand All @@ -201,7 +201,7 @@ def hello_tool(
assert len(tool.parameters) == 2
assert tool.parameters[0] == ToolParameter(name="a", type="string", required=True)
assert tool.parameters[1] == ToolParameter(name="b", type="string", required=True)
assert tool.run({"a": "1", "b": "2"}) == "hello 1 2"
assert tool.run(tool_input={"a": "1", "b": "2"}) == "hello 1 2"


def test_tool_to_langchain_tool():
Expand All @@ -220,7 +220,7 @@ class TestTool(BaseTool):
)
]

def run(self, parameters=None):
def run(self, **parameters):
return parameters["test_param"]

tool = TestTool()
Expand All @@ -243,7 +243,7 @@ def run(self, parameters=None):
"required": ["test_param"],
}

assert tool.run({"test_param": "value"}) == "value"
assert tool.run(**{"test_param": "value"}) == "value"
assert langchain_tool.invoke({"test_param": "value"}) == "value"


Expand Down Expand Up @@ -307,7 +307,7 @@ class ComplexTestTool(BaseTool):
),
]

def run(self, parameters=None):
def run(self, **parameters):
return (
parameters["required_integer"] if parameters["required_boolean"] else 0
)
Expand Down Expand Up @@ -395,8 +395,8 @@ def run(self, parameters=None):
"required_boolean": False,
"required_object": {"required_nested_string": "required_nested_string"},
}
assert tool.run(args_one) == 1
assert tool.run(args_zero) == 0
assert tool.run(**args_one) == 1
assert tool.run(**args_zero) == 0
assert langchain_tool.invoke(args_one) == 1
assert langchain_tool.invoke(args_zero) == 0

Expand Down Expand Up @@ -487,7 +487,7 @@ def test_nested_parameter_to_json_schema():

def test_baidu_search_tool():
tool = BaiduSearchTool(with_reference=True)
res = tool.run({"search_query": "上海天气"})
res = tool.run(**{"search_query": "上海天气"})
assert (
res["summary"]
== "**上海今天天气是晴转阴,气温在-4℃到1℃之间,风向无持续风向,"
Expand All @@ -507,20 +507,20 @@ def test_baidu_search_tool():
]

tool = BaiduSearchTool()
res = tool.run({"search_query": "上海天气"})
res = tool.run(**{"search_query": "上海天气"})
assert (
res
== "**上海今天天气是晴转阴,气温在-4℃到1℃之间,风向无持续风向,"
"风力小于3级,空气质量优**^[1]^。"
)

tool = BaiduSearchTool(with_reference=True)
res = tool.run({"search_query": "no_search"})
res = tool.run(**{"search_query": "no_search"})
assert res["reference"] == []

client = qianfan.Completion()
tool = BaiduSearchTool(client=client, with_reference=True)
res = tool.run({"search_query": "上海天气"})
res = tool.run(**{"search_query": "上海天气"})
assert (
res["summary"]
== "**上海今天天气是晴转阴,气温在-4℃到1℃之间,风向无持续风向,"
Expand Down

0 comments on commit c71aaec

Please sign in to comment.