From 9376861ededd0fd96c301bc1a83b9035470af396 Mon Sep 17 00:00:00 2001 From: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> Date: Tue, 28 Jan 2025 15:10:19 +0100 Subject: [PATCH] feat: migrating ros2 tools (#385) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bartłomiej Boczek --- src/rai/rai/communication/ros2/connectors.py | 50 ++++++++++- src/rai/rai/tools/ros2/__init__.py | 14 ++- src/rai/rai/tools/ros2/topics.py | 94 ++++++++++++++++++-- src/rai/rai/tools/ros2/utils.py | 37 ++++++++ tests/communication/ros2/helpers.py | 24 +++++ tests/tools/ros2/test_topic_tools.py | 53 ++++++++++- tests/tools/test_tool_utils.py | 24 +++++ 7 files changed, 286 insertions(+), 10 deletions(-) create mode 100644 src/rai/rai/tools/ros2/utils.py diff --git a/src/rai/rai/communication/ros2/connectors.py b/src/rai/rai/communication/ros2/connectors.py index cac72fdf..cda005aa 100644 --- a/src/rai/rai/communication/ros2/connectors.py +++ b/src/rai/rai/communication/ros2/connectors.py @@ -13,11 +13,18 @@ # limitations under the License. import threading +import time import uuid -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple +import rclpy +import rclpy.executors +import rclpy.node +import rclpy.time +from rclpy.duration import Duration from rclpy.executors import MultiThreadedExecutor from rclpy.node import Node +from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped from rai.communication.ari_connector import ARIConnector, ARIMessage from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI @@ -43,6 +50,9 @@ def __init__( self._thread = threading.Thread(target=self._executor.spin) self._thread.start() + def get_topics_names_and_types(self) -> List[Tuple[str, List[str]]]: + return self._topic_api.get_topic_names_and_types() + def send_message(self, message: ROS2ARIMessage, target: str): auto_qos_matching = message.metadata.get("auto_qos_matching", True) qos_profile = message.metadata.get("qos_profile", None) @@ -118,6 +128,44 @@ def start_action( raise RuntimeError("Action goal was not accepted") return handle + @staticmethod + def wait_for_transform( + tf_buffer: Buffer, + target_frame: str, + source_frame: str, + timeout_sec: float = 1.0, + ) -> bool: + start_time = time.time() + while time.time() - start_time < timeout_sec: + if tf_buffer.can_transform(target_frame, source_frame, rclpy.time.Time()): + return True + time.sleep(0.1) + return False + + def get_transform( + self, + target_frame: str, + source_frame: str, + timeout_sec: float = 5.0, + ) -> TransformStamped: + tf_buffer = Buffer(node=self._node) + tf_listener = TransformListener(tf_buffer, self._node) + transform_available = self.wait_for_transform( + tf_buffer, target_frame, source_frame, timeout_sec + ) + if not transform_available: + raise LookupException( + f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds" + ) + transform: TransformStamped = tf_buffer.lookup_transform( + target_frame, + source_frame, + rclpy.time.Time(), + timeout=Duration(seconds=timeout_sec), + ) + tf_listener.unregister() + return transform + def terminate_action(self, action_handle: str): self._actions_api.terminate_goal(action_handle) diff --git a/src/rai/rai/tools/ros2/__init__.py b/src/rai/rai/tools/ros2/__init__.py index 2ea85369..ea7d939f 100644 --- a/src/rai/rai/tools/ros2/__init__.py +++ b/src/rai/rai/tools/ros2/__init__.py @@ -14,13 +14,23 @@ from .actions import CancelROS2ActionTool, StartROS2ActionTool from .services import CallROS2ServiceTool -from .topics import GetImageTool, PublishROS2MessageTool, ReceiveROS2MessageTool +from .topics import ( + GetROS2ImageTool, + GetROS2MessageInterfaceTool, + GetROS2TopicsNamesAndTypesTool, + GetROS2TransformTool, + PublishROS2MessageTool, + ReceiveROS2MessageTool, +) __all__ = [ "StartROS2ActionTool", - "GetImageTool", + "GetROS2ImageTool", "PublishROS2MessageTool", "ReceiveROS2MessageTool", "CallROS2ServiceTool", "CancelROS2ActionTool", + "GetROS2TopicsNamesAndTypesTool", + "GetROS2MessageInterfaceTool", + "GetROS2TransformTool", ] diff --git a/src/rai/rai/tools/ros2/topics.py b/src/rai/rai/tools/ros2/topics.py index 27a05f42..b0c9e171 100644 --- a/src/rai/rai/tools/ros2/topics.py +++ b/src/rai/rai/tools/ros2/topics.py @@ -19,16 +19,21 @@ "This is a ROS2 feature. Make sure ROS2 is installed and sourced." ) +import json from typing import Any, Dict, Literal, Tuple, Type +import rosidl_runtime_py.set_message +import rosidl_runtime_py.utilities from cv_bridge import CvBridge -from langchain_core.tools import BaseTool +from langchain.tools import BaseTool +from langchain_core.utils import stringify_dict from pydantic import BaseModel, Field from sensor_msgs.msg import CompressedImage, Image from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage from rai.messages.multimodal import MultimodalArtifact from rai.messages.utils import preprocess_image +from rai.tools.ros2.utils import ros2_message_to_dict from rai.tools.utils import wrap_tool_input # type: ignore @@ -70,19 +75,19 @@ def _run(self, tool_input: ReceiveROS2MessageToolInput) -> str: return str({"payload": message.payload, "metadata": message.metadata}) -class GetImageToolInput(BaseModel): +class GetROS2ImageToolInput(BaseModel): topic: str = Field(..., description="The topic to receive the image from") -class GetImageTool(BaseTool): +class GetROS2ImageTool(BaseTool): connector: ROS2ARIConnector name: str = "get_ros2_image" description: str = "Get an image from a ROS2 topic" - args_schema: Type[GetImageToolInput] = GetImageToolInput + args_schema: Type[GetROS2ImageToolInput] = GetROS2ImageToolInput response_format: Literal["content", "content_and_artifact"] = "content_and_artifact" @wrap_tool_input - def _run(self, tool_input: GetImageToolInput) -> Tuple[str, MultimodalArtifact]: + def _run(self, tool_input: GetROS2ImageToolInput) -> Tuple[str, MultimodalArtifact]: message = self.connector.receive_message(tool_input.topic) msg_type = type(message.payload) if msg_type == Image: @@ -98,3 +103,82 @@ def _run(self, tool_input: GetImageToolInput) -> Tuple[str, MultimodalArtifact]: f"Unsupported message type: {message.metadata['msg_type']}" ) return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore + + +class GetROS2TopicsNamesAndTypesToolInput(BaseModel): + pass + + +class GetROS2TopicsNamesAndTypesTool(BaseTool): + connector: ROS2ARIConnector + name: str = "get_ros2_topics_names_and_types" + description: str = "Get the names and types of all ROS2 topics" + args_schema: Type[GetROS2TopicsNamesAndTypesToolInput] = ( + GetROS2TopicsNamesAndTypesToolInput + ) + + @wrap_tool_input + def _run(self, tool_input: GetROS2TopicsNamesAndTypesToolInput) -> str: + topics_and_types = self.connector.get_topics_names_and_types() + response = [ + stringify_dict({"topic": topic, "type": type}) + for topic, type in topics_and_types + ] + return "\n".join(response) + + +class GetROS2MessageInterfaceToolInput(BaseModel): + msg_type: str = Field( + ..., description="The type of the message e.g. std_msgs/msg/String" + ) + + +class GetROS2MessageInterfaceTool(BaseTool): + connector: ROS2ARIConnector + name: str = "get_ros2_message_interface" + description: str = "Get the interface of a ROS2 message" + args_schema: Type[GetROS2MessageInterfaceToolInput] = ( + GetROS2MessageInterfaceToolInput + ) + + @wrap_tool_input + def _run(self, tool_input: GetROS2MessageInterfaceToolInput) -> str: + """Show ros2 message interface in json format.""" + msg_cls: Type[object] = rosidl_runtime_py.utilities.get_interface( + tool_input.msg_type + ) + try: + msg_dict = ros2_message_to_dict(msg_cls()) # type: ignore + return json.dumps(msg_dict) + except NotImplementedError: + # For action classes that can't be instantiated + goal_dict = ros2_message_to_dict(msg_cls.Goal()) # type: ignore + + result_dict = ros2_message_to_dict(msg_cls.Result()) # type: ignore + + feedback_dict = ros2_message_to_dict(msg_cls.Feedback()) # type: ignore + return json.dumps( + {"goal": goal_dict, "result": result_dict, "feedback": feedback_dict} + ) + + +class GetROS2TransformToolInput(BaseModel): + target_frame: str = Field(..., description="The target frame") + source_frame: str = Field(..., description="The source frame") + timeout_sec: float = Field(default=5.0, description="The timeout in seconds") + + +class GetROS2TransformTool(BaseTool): + connector: ROS2ARIConnector + name: str = "get_ros2_transform" + description: str = "Get the transform between two frames" + args_schema: Type[GetROS2TransformToolInput] = GetROS2TransformToolInput + + @wrap_tool_input + def _run(self, tool_input: GetROS2TransformToolInput) -> str: + transform = self.connector.get_transform( + target_frame=tool_input.target_frame, + source_frame=tool_input.source_frame, + timeout_sec=tool_input.timeout_sec, + ) + return stringify_dict(ros2_message_to_dict(transform)) diff --git a/src/rai/rai/tools/ros2/utils.py b/src/rai/rai/tools/ros2/utils.py new file mode 100644 index 00000000..04679ed6 --- /dev/null +++ b/src/rai/rai/tools/ros2/utils.py @@ -0,0 +1,37 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, OrderedDict + +import rosidl_runtime_py.convert +import rosidl_runtime_py.set_message +import rosidl_runtime_py.utilities + + +def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]: + """Convert any ROS2 message into a dictionary. + + Args: + message: A ROS2 message instance + + Returns: + A dictionary representation of the message + + Raises: + TypeError: If the input is not a valid ROS2 message + """ + msg_dict: OrderedDict[str, Any] = rosidl_runtime_py.convert.message_to_ordereddict( + message + ) # type: ignore + return msg_dict diff --git a/tests/communication/ros2/helpers.py b/tests/communication/ros2/helpers.py index 9f1b36c2..99e2c696 100644 --- a/tests/communication/ros2/helpers.py +++ b/tests/communication/ros2/helpers.py @@ -28,6 +28,7 @@ from sensor_msgs.msg import Image from std_msgs.msg import String from std_srvs.srv import SetBool +from tf2_ros import TransformBroadcaster, TransformStamped class ServiceServer(Node): @@ -128,6 +129,29 @@ def publish_message(self) -> None: self.publisher.publish(msg) +class TransformPublisher(Node): + def __init__(self, topic: str): + super().__init__("test_transform_publisher") + self.tf_broadcaster = TransformBroadcaster(self) + self.timer = self.create_timer(0.1, self.publish_transform) + self.frame_id = "base_link" + self.child_frame_id = "map" + + def publish_transform(self) -> None: + msg = TransformStamped() + msg.header.stamp = self.get_clock().now().to_msg() # type: ignore + msg.header.frame_id = self.frame_id # type: ignore + msg.child_frame_id = self.child_frame_id # type: ignore + msg.transform.translation.x = 1.0 # type: ignore + msg.transform.translation.y = 2.0 # type: ignore + msg.transform.translation.z = 3.0 # type: ignore + msg.transform.rotation.x = 0.0 # type: ignore + msg.transform.rotation.y = 0.0 # type: ignore + msg.transform.rotation.z = 0.0 # type: ignore + msg.transform.rotation.w = 1.0 # type: ignore + self.tf_broadcaster.sendTransform(msg) + + def multi_threaded_spinner( nodes: List[Node], ) -> Tuple[List[MultiThreadedExecutor], List[threading.Thread]]: diff --git a/tests/tools/ros2/test_topic_tools.py b/tests/tools/ros2/test_topic_tools.py index 8836145b..4d615597 100644 --- a/tests/tools/ros2/test_topic_tools.py +++ b/tests/tools/ros2/test_topic_tools.py @@ -28,11 +28,19 @@ from PIL import Image from rai.communication.ros2.connectors import ROS2ARIConnector -from rai.tools.ros2 import GetImageTool, PublishROS2MessageTool, ReceiveROS2MessageTool +from rai.tools.ros2 import ( + GetROS2ImageTool, + GetROS2MessageInterfaceTool, + GetROS2TopicsNamesAndTypesTool, + GetROS2TransformTool, + PublishROS2MessageTool, + ReceiveROS2MessageTool, +) from tests.communication.ros2.helpers import ( ImagePublisher, MessagePublisher, MessageReceiver, + TransformPublisher, multi_threaded_spinner, ros_setup, shutdown_executors_and_threads, @@ -92,7 +100,7 @@ def test_receive_image_tool(ros_setup: None, request: pytest.FixtureRequest) -> connector = ROS2ARIConnector() publisher = ImagePublisher(topic=topic_name) executors, threads = multi_threaded_spinner([publisher]) - tool = GetImageTool(connector=connector) + tool = GetROS2ImageTool(connector=connector) time.sleep(1) try: _, artifact_dict = tool._run(topic=topic_name) # type: ignore @@ -103,3 +111,44 @@ def test_receive_image_tool(ros_setup: None, request: pytest.FixtureRequest) -> assert image.size == (100, 100) finally: shutdown_executors_and_threads(executors, threads) + + +def test_get_topics_names_and_types_tool( + ros_setup: None, request: pytest.FixtureRequest +) -> None: + connector = ROS2ARIConnector() + tool = GetROS2TopicsNamesAndTypesTool(connector=connector) + response = tool._run() + assert response != "" + + +def test_get_message_interface_tool( + ros_setup: None, request: pytest.FixtureRequest +) -> None: + connector = ROS2ARIConnector() + tool = GetROS2MessageInterfaceTool(connector=connector) + response = tool._run(msg_type="nav2_msgs/action/NavigateToPose") # type: ignore + assert "goal" in response + assert "result" in response + assert "feedback" in response + response = tool._run(msg_type="std_msgs/msg/String") # type: ignore + assert "data" in response + + +def test_get_transform_tool(ros_setup: None, request: pytest.FixtureRequest) -> None: + topic_name = f"{request.node.originalname}_topic" # type: ignore + connector = ROS2ARIConnector() + publisher = TransformPublisher(topic=topic_name) + executors, threads = multi_threaded_spinner([publisher]) + tool = GetROS2TransformTool(connector=connector) + time.sleep(1.0) + try: + response = tool._run( + target_frame=publisher.frame_id, + source_frame=publisher.child_frame_id, + timeout_sec=1.0, + ) # type: ignore + assert "translation" in response + assert "rotation" in response + finally: + shutdown_executors_and_threads(executors, threads) diff --git a/tests/tools/test_tool_utils.py b/tests/tools/test_tool_utils.py index 784763d9..c10b3b75 100644 --- a/tests/tools/test_tool_utils.py +++ b/tests/tools/test_tool_utils.py @@ -15,11 +15,17 @@ import logging from typing import Any, Dict, List, Type +import pytest +from geometry_msgs.msg import Point, TransformStamped from langchain_core.messages import AIMessage, ToolCall from langchain_core.tools import BaseTool +from nav2_msgs.action import NavigateToPose from pydantic import BaseModel, Field +from sensor_msgs.msg import Image +from tf2_msgs.msg import TFMessage from rai.agents.tool_runner import ToolRunner +from rai.tools.ros2.utils import ros2_message_to_dict from rai.tools.utils import wrap_tool_input @@ -57,3 +63,21 @@ def _run(self, tool_input: TestToolInput) -> str: _ = runner.invoke( {"messages": [AIMessage(content="Hello, how are you?", tool_calls=[tool_call])]} ) + + +# TODO(`maciejmajek`): Add custom RAI messages? +@pytest.mark.parametrize( + "message", + [ + Point(), + Image(), + TFMessage(), + TransformStamped(), + NavigateToPose.Goal(), + NavigateToPose.Result(), + NavigateToPose.Feedback(), + ], + ids=lambda x: x.__class__.__name__, +) +def test_ros2_message_to_dict(message): + assert ros2_message_to_dict(message)