Skip to content

Commit

Permalink
fix tools description rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
boczekbartek committed Jan 14, 2025
1 parent 0a6cb00 commit c6c2489
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/rai/rai/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import rclpy.task
import sensor_msgs.msg
from langchain.tools import BaseTool
from langchain.tools.render import render_text_description_and_args
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph.graph import CompiledGraph
from rclpy.action.server import ActionServer, GoalResponse, ServerGoalHandle
Expand All @@ -34,6 +33,7 @@

from rai.agents.state_based import Report, State, create_state_based_agent
from rai.messages import HumanMultimodalMessage
from rai.tools.render import render_text_description_with_args_from_tools
from rai.tools.ros.utils import convert_ros_img_to_base64
from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks
from rai.utils.ros_async import get_future_result
Expand Down Expand Up @@ -109,7 +109,7 @@ def append_tools_text_description_to_prompt(prompt: str, tools: List[BaseTool])
Use the tooling provided to gather information about the environment:
{render_text_description_and_args(tools)}
{render_text_description_with_args_from_tools(tools)}
"""


Expand Down
31 changes: 31 additions & 0 deletions src/rai/rai/tools/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from copy import deepcopy

from langchain.tools import BaseTool
from langchain_core.tools import (
create_schema_from_function,
render_text_description_and_args,
)


def filter_out_injected_tool_args(tool: BaseTool) -> BaseTool:
"""
Create a copy of tool with args filtered out.
The main purpose of this function is to make `langchain_core.tools.render_text_description_with_args`
usable for `@tool`s with `IntectedToolArg`s.
It doens't work without this modification, because implementation of StructuredTool
includes arguments annotated as `InjectedToolArg` in tool's `args_schema`,
which can cause `pydantic.errors.PydanticInvalidForJsonSchema` for arguments that
are not parsalbe by pydantic (like `rclpy.node.Node`)
"""
new_tool = deepcopy(tool)
new_tool.args_schema = create_schema_from_function(
tool.name, tool._run, include_injected=False
)
return new_tool


def render_text_description_with_args_from_tools(tools: list[BaseTool]) -> str:
tools = [filter_out_injected_tool_args(t) for t in tools]
return render_text_description_and_args(tools)
3 changes: 2 additions & 1 deletion tests/core/test_ros2_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from std_msgs.msg import String

from rai.agents.state_based import create_state_based_agent
from rai.tools.render import render_text_description_with_args_from_tools
from rai.tools.ros.native import Ros2PubMessageTool


Expand Down Expand Up @@ -83,7 +84,7 @@ def test_ros2_pub_message_tool_llm(
t = threading.Thread(target=executor.spin)

system = SystemMessage(
"You are a ros2 agent that can run tools: {render_text_description_and_args(tools)}"
f"You are a ros2 agent that can run tools: {render_text_description_with_args_from_tools(tools)}"
)
query = HumanMessage(
f"Publish a std_msgs/msg/String '{test_message}' to the topic '{topic}'"
Expand Down

0 comments on commit c6c2489

Please sign in to comment.