Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: migrating ros2 tools #385

Merged
merged 8 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/rai/rai/communication/ros2/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@
# limitations under the License.

import threading
import time
import uuid
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import rclpy
import rclpy.executors
import rclpy.node
import rclpy.time
from rclpy.duration import Duration
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped

from rai.communication.ari_connector import ARIConnector, ARIMessage
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
Expand All @@ -43,6 +50,9 @@ def __init__(
self._thread = threading.Thread(target=self._executor.spin)
self._thread.start()

def get_topics_names_and_types(self) -> List[Tuple[str, List[str]]]:
return self._topic_api.get_topic_names_and_types()

def send_message(self, message: ROS2ARIMessage, target: str):
auto_qos_matching = message.metadata.get("auto_qos_matching", True)
qos_profile = message.metadata.get("qos_profile", None)
Expand Down Expand Up @@ -118,6 +128,44 @@ def start_action(
raise RuntimeError("Action goal was not accepted")
return handle

@staticmethod
def wait_for_transform(
tf_buffer: Buffer,
target_frame: str,
source_frame: str,
timeout_sec: float = 1.0,
) -> bool:
start_time = time.time()
while time.time() - start_time < timeout_sec:
if tf_buffer.can_transform(target_frame, source_frame, rclpy.time.Time()):
return True
time.sleep(0.1)
return False

def get_transform(
self,
target_frame: str,
source_frame: str,
timeout_sec: float = 5.0,
) -> TransformStamped:
tf_buffer = Buffer(node=self._node)
tf_listener = TransformListener(tf_buffer, self._node)
transform_available = self.wait_for_transform(
tf_buffer, target_frame, source_frame, timeout_sec
)
if not transform_available:
raise LookupException(
f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds"
)
transform: TransformStamped = tf_buffer.lookup_transform(
target_frame,
source_frame,
rclpy.time.Time(),
timeout=Duration(seconds=timeout_sec),
)
tf_listener.unregister()
return transform

def terminate_action(self, action_handle: str):
self._actions_api.terminate_goal(action_handle)

Expand Down
14 changes: 12 additions & 2 deletions src/rai/rai/tools/ros2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@

from .actions import CancelROS2ActionTool, StartROS2ActionTool
from .services import CallROS2ServiceTool
from .topics import GetImageTool, PublishROS2MessageTool, ReceiveROS2MessageTool
from .topics import (
GetROS2ImageTool,
GetROS2MessageInterfaceTool,
GetROS2TopicsNamesAndTypesTool,
GetROS2TransformTool,
PublishROS2MessageTool,
ReceiveROS2MessageTool,
)

__all__ = [
"StartROS2ActionTool",
"GetImageTool",
"GetROS2ImageTool",
"PublishROS2MessageTool",
"ReceiveROS2MessageTool",
"CallROS2ServiceTool",
"CancelROS2ActionTool",
"GetROS2TopicsNamesAndTypesTool",
"GetROS2MessageInterfaceTool",
"GetROS2TransformTool",
]
94 changes: 89 additions & 5 deletions src/rai/rai/tools/ros2/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
)

import json
from typing import Any, Dict, Literal, Tuple, Type

import rosidl_runtime_py.set_message
import rosidl_runtime_py.utilities
from cv_bridge import CvBridge
from langchain_core.tools import BaseTool
from langchain.tools import BaseTool
from langchain_core.utils import stringify_dict
from pydantic import BaseModel, Field
from sensor_msgs.msg import CompressedImage, Image

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
from rai.messages.multimodal import MultimodalArtifact
from rai.messages.utils import preprocess_image
from rai.tools.ros2.utils import ros2_message_to_dict
from rai.tools.utils import wrap_tool_input # type: ignore


Expand Down Expand Up @@ -70,19 +75,19 @@ def _run(self, tool_input: ReceiveROS2MessageToolInput) -> str:
return str({"payload": message.payload, "metadata": message.metadata})


class GetImageToolInput(BaseModel):
class GetROS2ImageToolInput(BaseModel):
topic: str = Field(..., description="The topic to receive the image from")


class GetImageTool(BaseTool):
class GetROS2ImageTool(BaseTool):
connector: ROS2ARIConnector
name: str = "get_ros2_image"
description: str = "Get an image from a ROS2 topic"
args_schema: Type[GetImageToolInput] = GetImageToolInput
args_schema: Type[GetROS2ImageToolInput] = GetROS2ImageToolInput
response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"

@wrap_tool_input
def _run(self, tool_input: GetImageToolInput) -> Tuple[str, MultimodalArtifact]:
def _run(self, tool_input: GetROS2ImageToolInput) -> Tuple[str, MultimodalArtifact]:
message = self.connector.receive_message(tool_input.topic)
msg_type = type(message.payload)
if msg_type == Image:
Expand All @@ -98,3 +103,82 @@ def _run(self, tool_input: GetImageToolInput) -> Tuple[str, MultimodalArtifact]:
f"Unsupported message type: {message.metadata['msg_type']}"
)
return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore


class GetROS2TopicsNamesAndTypesToolInput(BaseModel):
pass


class GetROS2TopicsNamesAndTypesTool(BaseTool):
connector: ROS2ARIConnector
name: str = "get_ros2_topics_names_and_types"
description: str = "Get the names and types of all ROS2 topics"
args_schema: Type[GetROS2TopicsNamesAndTypesToolInput] = (
GetROS2TopicsNamesAndTypesToolInput
)

@wrap_tool_input
def _run(self, tool_input: GetROS2TopicsNamesAndTypesToolInput) -> str:
topics_and_types = self.connector.get_topics_names_and_types()
response = [
stringify_dict({"topic": topic, "type": type})
for topic, type in topics_and_types
]
return "\n".join(response)


class GetROS2MessageInterfaceToolInput(BaseModel):
msg_type: str = Field(
..., description="The type of the message e.g. std_msgs/msg/String"
)


class GetROS2MessageInterfaceTool(BaseTool):
connector: ROS2ARIConnector
name: str = "get_ros2_message_interface"
description: str = "Get the interface of a ROS2 message"
args_schema: Type[GetROS2MessageInterfaceToolInput] = (
GetROS2MessageInterfaceToolInput
)

@wrap_tool_input
def _run(self, tool_input: GetROS2MessageInterfaceToolInput) -> str:
"""Show ros2 message interface in json format."""
msg_cls: Type[object] = rosidl_runtime_py.utilities.get_interface(
tool_input.msg_type
)
try:
msg_dict = ros2_message_to_dict(msg_cls()) # type: ignore
return json.dumps(msg_dict)
except NotImplementedError:
# For action classes that can't be instantiated
goal_dict = ros2_message_to_dict(msg_cls.Goal()) # type: ignore

result_dict = ros2_message_to_dict(msg_cls.Result()) # type: ignore

feedback_dict = ros2_message_to_dict(msg_cls.Feedback()) # type: ignore
return json.dumps(
{"goal": goal_dict, "result": result_dict, "feedback": feedback_dict}
)


class GetROS2TransformToolInput(BaseModel):
target_frame: str = Field(..., description="The target frame")
source_frame: str = Field(..., description="The source frame")
timeout_sec: float = Field(default=5.0, description="The timeout in seconds")


class GetROS2TransformTool(BaseTool):
connector: ROS2ARIConnector
name: str = "get_ros2_transform"
description: str = "Get the transform between two frames"
args_schema: Type[GetROS2TransformToolInput] = GetROS2TransformToolInput

@wrap_tool_input
def _run(self, tool_input: GetROS2TransformToolInput) -> str:
transform = self.connector.get_transform(
target_frame=tool_input.target_frame,
source_frame=tool_input.source_frame,
timeout_sec=tool_input.timeout_sec,
)
return stringify_dict(ros2_message_to_dict(transform))
37 changes: 37 additions & 0 deletions src/rai/rai/tools/ros2/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, OrderedDict

import rosidl_runtime_py.convert
import rosidl_runtime_py.set_message
import rosidl_runtime_py.utilities


def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]:
"""Convert any ROS2 message into a dictionary.

Args:
message: A ROS2 message instance

Returns:
A dictionary representation of the message

Raises:
TypeError: If the input is not a valid ROS2 message
"""
msg_dict: OrderedDict[str, Any] = rosidl_runtime_py.convert.message_to_ordereddict(
message
) # type: ignore
return msg_dict
24 changes: 24 additions & 0 deletions tests/communication/ros2/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sensor_msgs.msg import Image
from std_msgs.msg import String
from std_srvs.srv import SetBool
from tf2_ros import TransformBroadcaster, TransformStamped


class ServiceServer(Node):
Expand Down Expand Up @@ -128,6 +129,29 @@ def publish_message(self) -> None:
self.publisher.publish(msg)


class TransformPublisher(Node):
def __init__(self, topic: str):
super().__init__("test_transform_publisher")
self.tf_broadcaster = TransformBroadcaster(self)
self.timer = self.create_timer(0.1, self.publish_transform)
self.frame_id = "base_link"
self.child_frame_id = "map"

def publish_transform(self) -> None:
msg = TransformStamped()
msg.header.stamp = self.get_clock().now().to_msg() # type: ignore
msg.header.frame_id = self.frame_id # type: ignore
msg.child_frame_id = self.child_frame_id # type: ignore
msg.transform.translation.x = 1.0 # type: ignore
msg.transform.translation.y = 2.0 # type: ignore
msg.transform.translation.z = 3.0 # type: ignore
msg.transform.rotation.x = 0.0 # type: ignore
msg.transform.rotation.y = 0.0 # type: ignore
msg.transform.rotation.z = 0.0 # type: ignore
msg.transform.rotation.w = 1.0 # type: ignore
self.tf_broadcaster.sendTransform(msg)


def multi_threaded_spinner(
nodes: List[Node],
) -> Tuple[List[MultiThreadedExecutor], List[threading.Thread]]:
Expand Down
53 changes: 51 additions & 2 deletions tests/tools/ros2/test_topic_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@
from PIL import Image

from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros2 import GetImageTool, PublishROS2MessageTool, ReceiveROS2MessageTool
from rai.tools.ros2 import (
GetROS2ImageTool,
GetROS2MessageInterfaceTool,
GetROS2TopicsNamesAndTypesTool,
GetROS2TransformTool,
PublishROS2MessageTool,
ReceiveROS2MessageTool,
)
from tests.communication.ros2.helpers import (
ImagePublisher,
MessagePublisher,
MessageReceiver,
TransformPublisher,
multi_threaded_spinner,
ros_setup,
shutdown_executors_and_threads,
Expand Down Expand Up @@ -92,7 +100,7 @@ def test_receive_image_tool(ros_setup: None, request: pytest.FixtureRequest) ->
connector = ROS2ARIConnector()
publisher = ImagePublisher(topic=topic_name)
executors, threads = multi_threaded_spinner([publisher])
tool = GetImageTool(connector=connector)
tool = GetROS2ImageTool(connector=connector)
time.sleep(1)
try:
_, artifact_dict = tool._run(topic=topic_name) # type: ignore
Expand All @@ -103,3 +111,44 @@ def test_receive_image_tool(ros_setup: None, request: pytest.FixtureRequest) ->
assert image.size == (100, 100)
finally:
shutdown_executors_and_threads(executors, threads)


def test_get_topics_names_and_types_tool(
ros_setup: None, request: pytest.FixtureRequest
) -> None:
connector = ROS2ARIConnector()
tool = GetROS2TopicsNamesAndTypesTool(connector=connector)
response = tool._run()
assert response != ""


def test_get_message_interface_tool(
ros_setup: None, request: pytest.FixtureRequest
) -> None:
connector = ROS2ARIConnector()
tool = GetROS2MessageInterfaceTool(connector=connector)
response = tool._run(msg_type="nav2_msgs/action/NavigateToPose") # type: ignore
assert "goal" in response
assert "result" in response
assert "feedback" in response
response = tool._run(msg_type="std_msgs/msg/String") # type: ignore
assert "data" in response


def test_get_transform_tool(ros_setup: None, request: pytest.FixtureRequest) -> None:
topic_name = f"{request.node.originalname}_topic" # type: ignore
connector = ROS2ARIConnector()
publisher = TransformPublisher(topic=topic_name)
executors, threads = multi_threaded_spinner([publisher])
tool = GetROS2TransformTool(connector=connector)
time.sleep(1.0)
maciejmajek marked this conversation as resolved.
Show resolved Hide resolved
try:
response = tool._run(
target_frame=publisher.frame_id,
source_frame=publisher.child_frame_id,
timeout_sec=1.0,
) # type: ignore
assert "translation" in response
assert "rotation" in response
finally:
shutdown_executors_and_threads(executors, threads)
Loading