From 86019fc5ccd566baf9aac7238d5073ac06c36eb4 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 24 Jan 2025 16:59:27 +0100 Subject: [PATCH] refactor(test): use ToolRunner to run tool in test_wrap_tool_input --- tests/tools/test_tool_utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/tools/test_tool_utils.py b/tests/tools/test_tool_utils.py index 19a598e5..784763d9 100644 --- a/tests/tools/test_tool_utils.py +++ b/tests/tools/test_tool_utils.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Any, Dict, List, Type +from langchain_core.messages import AIMessage, ToolCall from langchain_core.tools import BaseTool from pydantic import BaseModel, Field +from rai.agents.tool_runner import ToolRunner from rai.tools.utils import wrap_tool_input @@ -43,13 +46,14 @@ def _run(self, tool_input: TestToolInput) -> str: return "done" tool = TestTool() - result = tool.invoke( - { - "a": 1, - "b": "test", - "c": {"a": 1, "b": 2}, - "d": [1, 2, 3], - "e": b"test", - } + logger = logging.getLogger(__name__) + runner = ToolRunner(tools=[tool], logger=logger) + tool_call = ToolCall( + name="test_tool", + args={"a": 1, "b": "test", "c": {"a": 1, "b": 2}, "d": [1, 2, 3], "e": b"test"}, + id="123", + ) + + _ = runner.invoke( + {"messages": [AIMessage(content="Hello, how are you?", tool_calls=[tool_call])]} ) - assert result == "done"