Skip to content

Commit

Permalink
chore: remove langchain _run method wrapper (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek authored Jan 29, 2025
1 parent 6a19ab2 commit df4a8ae
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 112 deletions.
4 changes: 2 additions & 2 deletions src/rai/rai/tools/ros/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _run(
x: float,
y: float,
z: float,
task: Literal["grab", "place"],
task: Literal["grab", "drop"],
) -> str:
pose_stamped = PoseStamped()
pose_stamped.header.frame_id = self.manipulator_frame
Expand All @@ -96,7 +96,7 @@ def _run(
orientation=self.quaternion,
)

if task == "place":
if task == "drop":
pose_stamped.pose.position.z += self.additional_height

pose_stamped.pose.position.x += self.calibration_x
Expand Down
2 changes: 1 addition & 1 deletion src/rai/rai/tools/ros/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from rclpy.impl.rcutils_logger import RcutilsLogger
from rosidl_runtime_py.utilities import get_namespaced_type

from .utils import convert_ros_img_to_base64, import_message_from_str
from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str


# --------------------- Inputs ---------------------
Expand Down
6 changes: 3 additions & 3 deletions src/rai/rai/tools/ros/native_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from rclpy.action.client import ActionClient
from rosidl_runtime_py import message_to_ordereddict

from .native import Ros2BaseInput, Ros2BaseTool
from .utils import get_transform
from rai.tools.ros.native import Ros2BaseInput, Ros2BaseTool
from rai.tools.ros.utils import get_transform


# --------------------- Inputs ---------------------
Expand Down Expand Up @@ -197,7 +197,7 @@ class GetTransformTool(Ros2BaseActionTool):

args_schema: Type[GetTransformInput] = GetTransformInput

def _run(self, target_frame="map", source_frame="body_link") -> dict:
def _run(self, target_frame: str = "map", source_frame: str = "body_link") -> dict:
return message_to_ordereddict(
get_transform(self.node, target_frame, source_frame)
)
3 changes: 1 addition & 2 deletions src/rai/rai/tools/ros/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
from tf_transformations import euler_from_quaternion

from rai.tools.ros.deprecated import SingleMessageGrabber
from rai.tools.ros.native import TopicInput
from rai.tools.utils import TF2TransformFetcher

from .native import TopicInput

logger = logging.getLogger(__name__)


Expand Down
23 changes: 12 additions & 11 deletions src/rai/rai/tools/ros2/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
from pydantic import BaseModel, Field

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
from rai.tools.utils import wrap_tool_input # type: ignore


class StartROS2ActionToolInput(BaseModel):
action_name: str = Field(..., description="The name of the action to start")
action_type: str = Field(..., description="The type of the action")
args: Dict[str, Any] = Field(..., description="The arguments to pass to the action")
action_args: Dict[str, Any] = Field(
..., description="The arguments to pass to the action"
)


class StartROS2ActionTool(BaseTool):
Expand All @@ -42,15 +43,16 @@ class StartROS2ActionTool(BaseTool):
description: str = "Start a ROS2 action"
args_schema: Type[StartROS2ActionToolInput] = StartROS2ActionToolInput

@wrap_tool_input
def _run(self, tool_input: StartROS2ActionToolInput) -> str:
message = ROS2ARIMessage(payload=tool_input.args)
def _run(
self, action_name: str, action_type: str, action_args: Dict[str, Any]
) -> str:
message = ROS2ARIMessage(payload=action_args)
response = self.connector.start_action(
message,
tool_input.action_name,
action_name,
on_feedback=self.feedback_callback,
on_done=self.on_done_callback,
msg_type=tool_input.action_type,
msg_type=action_type,
)
return "Action started with ID: " + response

Expand All @@ -65,10 +67,9 @@ class CancelROS2ActionTool(BaseTool):
description: str = "Cancel a ROS2 action"
args_schema: Type[CancelROS2ActionToolInput] = CancelROS2ActionToolInput

@wrap_tool_input
def _run(self, tool_input: CancelROS2ActionToolInput) -> str:
self.connector.terminate_action(tool_input.action_id)
return f"Action {tool_input.action_id} cancelled"
def _run(self, action_id: str) -> str:
self.connector.terminate_action(action_id)
return f"Action {action_id} cancelled"


@tool
Expand Down
12 changes: 6 additions & 6 deletions src/rai/rai/tools/ros2/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from pydantic import BaseModel, Field

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
from rai.tools.utils import wrap_tool_input # type: ignore


class CallROS2ServiceToolInput(BaseModel):
service_name: str = Field(..., description="The service to call")
service_type: str = Field(..., description="The type of the service")
args: Dict[str, Any] = Field(
service_args: Dict[str, Any] = Field(
..., description="The arguments to pass to the service"
)

Expand All @@ -42,11 +41,12 @@ class CallROS2ServiceTool(BaseTool):
description: str = "Call a ROS2 service"
args_schema: Type[CallROS2ServiceToolInput] = CallROS2ServiceToolInput

@wrap_tool_input
def _run(self, tool_input: CallROS2ServiceToolInput) -> str:
message = ROS2ARIMessage(payload=tool_input.args)
def _run(
self, service_name: str, service_type: str, service_args: Dict[str, Any]
) -> str:
message = ROS2ARIMessage(payload=service_args)
response = self.connector.service_call(
message, tool_input.service_name, msg_type=tool_input.service_type
message, service_name, msg_type=service_type
)
return str(
{
Expand Down
48 changes: 15 additions & 33 deletions src/rai/rai/tools/ros2/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from rai.messages.multimodal import MultimodalArtifact
from rai.messages.utils import preprocess_image
from rai.tools.ros2.utils import ros2_message_to_dict
from rai.tools.utils import wrap_tool_input # type: ignore


class PublishROS2MessageToolInput(BaseModel):
Expand All @@ -49,15 +48,12 @@ class PublishROS2MessageTool(BaseTool):
description: str = "Publish a message to a ROS2 topic"
args_schema: Type[PublishROS2MessageToolInput] = PublishROS2MessageToolInput

@wrap_tool_input
def _run(self, tool_input: PublishROS2MessageToolInput) -> str:
def _run(self, topic: str, message: Dict[str, Any], message_type: str) -> str:
ros_message = ROS2ARIMessage(
payload=tool_input.message,
metadata={"topic": tool_input.topic},
)
self.connector.send_message(
ros_message, target=tool_input.topic, msg_type=tool_input.message_type
payload=message,
metadata={"topic": topic},
)
self.connector.send_message(ros_message, target=topic, msg_type=message_type)
return "Message published successfully"


Expand All @@ -71,9 +67,8 @@ class ReceiveROS2MessageTool(BaseTool):
description: str = "Receive a message from a ROS2 topic"
args_schema: Type[ReceiveROS2MessageToolInput] = ReceiveROS2MessageToolInput

@wrap_tool_input
def _run(self, tool_input: ReceiveROS2MessageToolInput) -> str:
message = self.connector.receive_message(tool_input.topic)
def _run(self, topic: str) -> str:
message = self.connector.receive_message(topic)
return str({"payload": message.payload, "metadata": message.metadata})


Expand All @@ -88,9 +83,8 @@ class GetROS2ImageTool(BaseTool):
args_schema: Type[GetROS2ImageToolInput] = GetROS2ImageToolInput
response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"

@wrap_tool_input
def _run(self, tool_input: GetROS2ImageToolInput) -> Tuple[str, MultimodalArtifact]:
message = self.connector.receive_message(tool_input.topic)
def _run(self, topic: str) -> Tuple[str, MultimodalArtifact]:
message = self.connector.receive_message(topic)
msg_type = type(message.payload)
if msg_type == Image:
image = CvBridge().imgmsg_to_cv2( # type: ignore
Expand All @@ -107,20 +101,12 @@ def _run(self, tool_input: GetROS2ImageToolInput) -> Tuple[str, MultimodalArtifa
return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore


class GetROS2TopicsNamesAndTypesToolInput(BaseModel):
pass


class GetROS2TopicsNamesAndTypesTool(BaseTool):
connector: ROS2ARIConnector
name: str = "get_ros2_topics_names_and_types"
description: str = "Get the names and types of all ROS2 topics"
args_schema: Type[GetROS2TopicsNamesAndTypesToolInput] = (
GetROS2TopicsNamesAndTypesToolInput
)

@wrap_tool_input
def _run(self, tool_input: GetROS2TopicsNamesAndTypesToolInput) -> str:
def _run(self) -> str:
topics_and_types = self.connector.get_topics_names_and_types()
response = [
stringify_dict({"topic": topic, "type": type})
Expand All @@ -143,12 +129,9 @@ class GetROS2MessageInterfaceTool(BaseTool):
GetROS2MessageInterfaceToolInput
)

@wrap_tool_input
def _run(self, tool_input: GetROS2MessageInterfaceToolInput) -> str:
def _run(self, msg_type: str) -> str:
"""Show ros2 message interface in json format."""
msg_cls: Type[object] = rosidl_runtime_py.utilities.get_interface(
tool_input.msg_type
)
msg_cls: Type[object] = rosidl_runtime_py.utilities.get_interface(msg_type)
try:
msg_dict = ros2_message_to_dict(msg_cls()) # type: ignore
return json.dumps(msg_dict)
Expand Down Expand Up @@ -176,11 +159,10 @@ class GetROS2TransformTool(BaseTool):
description: str = "Get the transform between two frames"
args_schema: Type[GetROS2TransformToolInput] = GetROS2TransformToolInput

@wrap_tool_input
def _run(self, tool_input: GetROS2TransformToolInput) -> str:
def _run(self, target_frame: str, source_frame: str, timeout_sec: float) -> str:
transform = self.connector.get_transform(
target_frame=tool_input.target_frame,
source_frame=tool_input.source_frame,
timeout_sec=tool_input.timeout_sec,
target_frame=target_frame,
source_frame=source_frame,
timeout_sec=timeout_sec,
)
return stringify_dict(ros2_message_to_dict(transform))
9 changes: 0 additions & 9 deletions src/rai/rai/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import base64
import logging
import subprocess
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Sequence, Union, cast

import cv2
Expand Down Expand Up @@ -44,14 +43,6 @@
from rai.messages import ToolMultimodalMessage


def wrap_tool_input(func): # type: ignore
@wraps(func) # type: ignore
def wrapped(self: BaseTool, **kwargs): # type: ignore
return func(self, self.args_schema(**kwargs)) # type: ignore

return wrapped # type: ignore


# Copied from https://github.com/ros2/rclpy/blob/jazzy/rclpy/rclpy/wait_for_message.py, to support humble
def wait_for_message(
msg_type,
Expand Down
2 changes: 1 addition & 1 deletion tests/tools/ros2/test_action_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_action_call_tool(ros_setup: None, request: pytest.FixtureRequest) -> No
response = tool._run( # type: ignore
action_name=action_name,
action_type="nav2_msgs/action/NavigateToPose",
args={},
action_args={},
)
assert "Action started with ID:" in response

Expand Down
2 changes: 1 addition & 1 deletion tests/tools/ros2/test_service_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_service_call_tool(ros_setup: None, request: pytest.FixtureRequest) -> N
response = tool._run( # type: ignore
service_name=service_name,
service_type="std_srvs/srv/SetBool",
args={},
service_args={},
)
assert "Test service called" in response
assert "success=True" in response
Expand Down
78 changes: 78 additions & 0 deletions tests/tools/test_tool_input_args_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import inspect
from pathlib import Path

import pytest
from langchain_core.tools import BaseTool


def get_all_tool_classes() -> set[BaseTool]:
"""Recursively find all classes that inherit from pydantic.BaseModel in src/rai/rai/tools"""
tools = []
tools_path = Path("src/rai/rai/tools")

# Recursively find all .py files
for py_file in tools_path.rglob("*.py"):
if py_file.name.startswith("_"): # Skip __init__.py and similar
continue

try:
# Manual module loading since files aren't in __init__
module_name = py_file.stem
spec = importlib.util.spec_from_file_location(module_name, py_file)
if spec is None or spec.loader is None:
continue

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

for name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and issubclass(obj, BaseTool)
and obj != BaseTool
):
tools.append(obj)
except Exception as e:
print(f"Failed to process {py_file}: {e}")

return set(tools)


@pytest.mark.parametrize("tool", get_all_tool_classes())
def test_tool_input_args_compatibility(tool: BaseTool):
tool_run_annotations = tool._run.__annotations__
if "return" in tool_run_annotations:
tool_run_annotations.pop("return")
if "args" in tool_run_annotations and "kwargs" in tool_run_annotations:
print(
f"Tool {tool} has *args or **kwargs, the _run method is most likely still an abstractmethod"
)
pytest.xfail(
reason="Tool has *args or **kwargs, the _run method is most likely still an abstractmethod"
)
if "args_schema" not in tool.__annotations__:
print(f"Tool {tool} has no args_schema")
pytest.xfail(reason="Tool has no args_schema")

if len(tool.__annotations__["args_schema"].__args__) != 1:
raise NotImplementedError(f"Tool {tool} has ambiguous args_schema")

tool_input_annotations = (
tool.__annotations__["args_schema"].__args__[0].__annotations__
)
assert tool_run_annotations == tool_input_annotations
Loading

0 comments on commit df4a8ae

Please sign in to comment.