From 497308b8ec26b68157591db02d11708a4ce45a99 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 9 Jan 2025 14:40:42 +0100 Subject: [PATCH 01/14] feat: impl minimal BaseMessage --- src/rai/rai/communication/base_connector.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/rai/rai/communication/base_connector.py b/src/rai/rai/communication/base_connector.py index fe01097f..996de151 100644 --- a/src/rai/rai/communication/base_connector.py +++ b/src/rai/rai/communication/base_connector.py @@ -13,12 +13,20 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Callable +from typing import Any, Callable from uuid import uuid4 class BaseMessage(ABC): - pass + def __init__(self, content: Any): + self.content = content + + def __repr__(self): + return f"{self.__class__.__name__}({self.content=})" + + @property + def msg_type(self) -> Any: + return type(self.content) class BaseConnector(ABC): From a6d337c8d793537db65bf70fb5f86364b727710d Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 9 Jan 2025 14:44:48 +0100 Subject: [PATCH 02/14] feat: implement ros_connector --- src/rai/rai/communication/ros_connector.py | 109 +++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 src/rai/rai/communication/ros_connector.py diff --git a/src/rai/rai/communication/ros_connector.py b/src/rai/rai/communication/ros_connector.py new file mode 100644 index 00000000..e841596a --- /dev/null +++ b/src/rai/rai/communication/ros_connector.py @@ -0,0 +1,109 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import threading +from typing import Callable, Dict + +import rclpy +from rclpy.executors import MultiThreadedExecutor +from rclpy.node import Node +from rclpy.publisher import Publisher +from rclpy.qos import ( + QoSDurabilityPolicy, + QoSHistoryPolicy, + QoSProfile, + QoSReliabilityPolicy, +) + +from rai.communication.base_connector import BaseConnector, BaseMessage +from rai.tools.ros.utils import import_message_from_str +from rai.tools.utils import wait_for_message + + +class ROS2Connector(BaseConnector): + def __init__(self): + if not rclpy.ok(): + rclpy.init() + + self.node = Node(node_name="rai_ros2_connector") + self.publishers: Dict[str, Publisher] = {} + self.default_qos_profile = QoSProfile( + depth=10, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.SYSTEM_DEFAULT, + reliability=QoSReliabilityPolicy.RELIABLE, + ) + + self.executor = MultiThreadedExecutor() + self.executor.add_node(self.node) + self.executor_thread = threading.Thread(target=self.executor.spin) + self.executor_thread.start() + atexit.register(self.cleanup) + + def send_message(self, msg: BaseMessage, target: str) -> None: + publisher = self.publishers.get(target) + if publisher is None: + self.publishers[target] = self.node.create_publisher( + msg.msg_type, target, qos_profile=self.default_qos_profile + ) + self.publishers[target].publish(msg.content) + + def _validate_and_get_msg_type(self, topic: str): + """ + Validate that the topic exists and return the message type. + """ + topic_names_and_types = self.node.get_topic_names_and_types() + topic_names = [topic for topic, _ in topic_names_and_types] + if topic not in topic_names: + raise ValueError( + f"Topic '{topic}' not found. Available topics: {topic_names}" + ) + return topic_names_and_types[topic_names.index(topic)][1][0] + + def receive_message(self, source: str) -> BaseMessage: + msg_type = self._validate_and_get_msg_type(source) + status, msg = wait_for_message( + import_message_from_str(msg_type), + self.node, + source, + qos_profile=self.default_qos_profile, + ) + + if status: + return BaseMessage(content=msg) + else: + raise ValueError(f"No message found for {source}") + + def start_action( + self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None + ) -> str: + raise NotImplementedError( + f"{self.__class__.__name__} does not suport starting actions" + ) + + def terminate_action(self, action_handle: str): + raise NotImplementedError( + f"{self.__class__.__name__} does not suport terminating actions" + ) + + def send_and_wait(self, target: str) -> BaseMessage: + raise NotImplementedError( + f"{self.__class__.__name__} does not suport sending messages" + ) + + def cleanup(self): + self.executor.shutdown() + self.executor_thread.join() + self.node.destroy_node() From 035b472277b4e4efcc3ef810b549fc222354d17a Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 9 Jan 2025 14:50:09 +0100 Subject: [PATCH 03/14] feat: implement http_connector --- src/rai/rai/communication/http_connector.py | 145 ++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 src/rai/rai/communication/http_connector.py diff --git a/src/rai/rai/communication/http_connector.py b/src/rai/rai/communication/http_connector.py new file mode 100644 index 00000000..8fbda52a --- /dev/null +++ b/src/rai/rai/communication/http_connector.py @@ -0,0 +1,145 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import json +import threading +from datetime import datetime +from http.server import HTTPServer, SimpleHTTPRequestHandler +from typing import Any, Callable, Dict, List + +from rai.communication.base_connector import BaseConnector, BaseMessage + + +class MessageHandler(SimpleHTTPRequestHandler): + """Handler for HTTP requests serving a simple message viewer.""" + + def do_GET(self): + """Serve either the main HTML page or message data.""" + if self.path == "/": + self._serve_html() + elif self.path == "/messages": + self._serve_messages() + + def _serve_html(self): + """Serve the main HTML interface.""" + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(self.server.get_html_content().encode()) + + def _serve_messages(self): + """Serve message data as JSON.""" + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(self.server.messages).encode()) + + +class SimpleHTTPServer(HTTPServer): + """HTTP server that maintains a list of messages.""" + + def __init__(self, server_address, RequestHandlerClass): + super().__init__(server_address, RequestHandlerClass) + self.messages: List[Dict[str, Any]] = [] + + def add_message(self, message: str): + """Add a new message with timestamp.""" + self.messages.append( + {"timestamp": datetime.now().isoformat(), "content": message} + ) + + def get_html_content(self) -> str: + """Return the HTML content for the web interface.""" + return """ + + + + Simple Message Viewer + + + +

Messages

+
+ + + + """ + + +class HTTPConnector(BaseConnector): + """Connector that displays messages via a web interface.""" + + def __init__(self, host: str = "localhost", port: int = 8000): + self.host = host + self.port = port + + self.server = SimpleHTTPServer((self.host, self.port), MessageHandler) + self.server_thread = threading.Thread(target=self.server.serve_forever) + self.server_thread.daemon = True + self.server_thread.start() + + atexit.register(self.cleanup) + print(f"Server started at http://{self.host}:{self.port}") + + def send_message(self, msg: BaseMessage, target: str) -> None: + """Add message to the web interface.""" + self.server.add_message(str(msg.content)) + + def receive_message(self, source: str) -> BaseMessage: + raise NotImplementedError( + f"{self.__class__.__name__} does not support receiving messages" + ) + + def start_action( + self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None # type: ignore + ) -> str: + raise NotImplementedError( + f"{self.__class__.__name__} does not suport starting actions" + ) + + def terminate_action(self, action_handle: str): + raise NotImplementedError( + f"{self.__class__.__name__} does not suport terminating actions" + ) + + def send_and_wait(self, target: str) -> BaseMessage: + raise NotImplementedError( + f"{self.__class__.__name__} does not suport sending messages" + ) + + def cleanup(self): + """Clean up server resources.""" + self.server.shutdown() + self.server.server_close() From 07759941d89c7f3ec89d5eaad0a5672e1def839b Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 9 Jan 2025 15:42:31 +0100 Subject: [PATCH 04/14] test: ros_connector --- tests/communication/test_ros_connector.py | 158 ++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 tests/communication/test_ros_connector.py diff --git a/tests/communication/test_ros_connector.py b/tests/communication/test_ros_connector.py new file mode 100644 index 00000000..c3afe605 --- /dev/null +++ b/tests/communication/test_ros_connector.py @@ -0,0 +1,158 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from typing import List, Optional + +import pytest +import rclpy +from rclpy.callback_groups import ReentrantCallbackGroup +from rclpy.executors import MultiThreadedExecutor +from rclpy.node import Node +from sensor_msgs.msg import Image +from std_msgs.msg import String + +from rai.communication.base_connector import BaseMessage +from rai.communication.ros_connector import ROS2Connector + + +class ReceiverNode(Node): + def __init__(self): + super().__init__("receiver") + self.received_images: List[Image] = [] + self.received_strings: List[String] = [] + self.group = ReentrantCallbackGroup() + self.string_sub = self.create_subscription( # type: ignore + String, "/text", self.on_message, 10, callback_group=self.group + ) + self.image_sub = self.create_subscription( # type: ignore + Image, "/image", self.on_message, 10, callback_group=self.group + ) + + def on_message(self, msg: Optional[String | Image]): + if isinstance(msg, String): + self.received_strings.append(msg) + elif isinstance(msg, Image): + self.received_images.append(msg) + else: + raise ValueError(f"Unknown message type: {type(msg)}") + + +class SenderNode(Node): + def __init__(self): + super().__init__("sender") + self.text_pub = self.create_publisher( # type: ignore + String, + "/text", + 10, + ) + self.image_pub = self.create_publisher( # type: ignore + Image, + "/image", + 10, + ) + self.timer = self.create_timer(0.1, self.on_timer) # type: ignore + + def on_timer(self): + self.text_pub.publish(String(data="Hello, world!")) + self.image_pub.publish(Image()) + + +def test_send_message(): + rclpy.init() + receiver = ReceiverNode() + receiver_executor = MultiThreadedExecutor() + receiver_executor.add_node(receiver) + receiver_thread = threading.Thread(target=receiver_executor.spin) + receiver_thread.start() + + connector = ROS2Connector() + connector.send_message(BaseMessage(String(data="Hello, world!")), "/text") + connector.send_message(BaseMessage(Image()), "/image") + + time.sleep(1.0) + assert len(receiver.received_strings) == 1 + assert len(receiver.received_images) == 1 + + rclpy.shutdown() + receiver_thread.join() + + +def test_receive_message(): + rclpy.init() + sender = SenderNode() + sender_executor = MultiThreadedExecutor() + sender_executor.add_node(sender) + sender_thread = threading.Thread(target=sender_executor.spin) + sender_thread.start() + + connector = ROS2Connector() + message = connector.receive_message("/text") + assert isinstance(message, BaseMessage) + assert isinstance(message.content, String) + + message = connector.receive_message("/image") + assert isinstance(message, BaseMessage) + assert isinstance(message.content, Image) + + rclpy.shutdown() + sender_thread.join() + + +@pytest.mark.skip(reason="Not implemented") +def test_start_action(): + pass + + +@pytest.mark.skip(reason="Not implemented") +def test_terminate_action(): + pass + + +@pytest.mark.skip(reason="Not implemented") +def test_send_and_wait(): + pass + + +def test_connector_cleanup(): + rclpy.init() + connector = ROS2Connector() + + connector.send_message(BaseMessage(String(data="Test")), "/text") + connector.send_message(BaseMessage(Image()), "/image") + + initial_publisher_count = len(connector.publishers) + assert initial_publisher_count == 2 + + connector.send_message(BaseMessage(String(data="Test")), "/text") + assert ( + len(connector.publishers) == initial_publisher_count + ) # reuses existing publishers + + connector.cleanup() + assert not connector.executor_thread.is_alive() + + rclpy.shutdown() + assert not connector.node.context.ok() + + +def test_invalid_topic(): + rclpy.init() + connector = ROS2Connector() + + with pytest.raises(ValueError, match="Topic '/nonexistent' not found"): + connector.receive_message("/nonexistent") + + rclpy.shutdown() From 4db7e1b8609a71be7b8ae92724450862114c75c3 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 9 Jan 2025 19:56:15 +0100 Subject: [PATCH 05/14] test: split test_connector_cleanup into two tests --- tests/communication/test_ros_connector.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/communication/test_ros_connector.py b/tests/communication/test_ros_connector.py index c3afe605..d1262f8a 100644 --- a/tests/communication/test_ros_connector.py +++ b/tests/communication/test_ros_connector.py @@ -126,6 +126,20 @@ def test_send_and_wait(): pass +def test_connector_publisher_reuse(): + rclpy.init() + connector = ROS2Connector() + connector.send_message(BaseMessage(String(data="Test")), "/text") + connector.send_message(BaseMessage(Image()), "/image") + + assert len(connector.publishers) == 2 + + connector.send_message(BaseMessage(String(data="Test")), "/text") + assert len(connector.publishers) == 2 + connector.cleanup() + rclpy.shutdown() + + def test_connector_cleanup(): rclpy.init() connector = ROS2Connector() @@ -136,11 +150,6 @@ def test_connector_cleanup(): initial_publisher_count = len(connector.publishers) assert initial_publisher_count == 2 - connector.send_message(BaseMessage(String(data="Test")), "/text") - assert ( - len(connector.publishers) == initial_publisher_count - ) # reuses existing publishers - connector.cleanup() assert not connector.executor_thread.is_alive() From 8f82fcf5787fb81a38d96f300fb2fb448cba58e2 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 9 Jan 2025 19:56:52 +0100 Subject: [PATCH 06/14] feat: allow custom qos profile --- src/rai/rai/communication/ros_connector.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/rai/rai/communication/ros_connector.py b/src/rai/rai/communication/ros_connector.py index e841596a..3354f26e 100644 --- a/src/rai/rai/communication/ros_connector.py +++ b/src/rai/rai/communication/ros_connector.py @@ -14,7 +14,7 @@ import atexit import threading -from typing import Callable, Dict +from typing import Callable, Dict, Optional import rclpy from rclpy.executors import MultiThreadedExecutor @@ -33,13 +33,17 @@ class ROS2Connector(BaseConnector): - def __init__(self): + def __init__( + self, + node_name: str = "rai_ros2_connector", + qos_profile: Optional[QoSProfile] = None, + ): if not rclpy.ok(): rclpy.init() - self.node = Node(node_name="rai_ros2_connector") + self.node = Node(node_name=node_name) self.publishers: Dict[str, Publisher] = {} - self.default_qos_profile = QoSProfile( + self.qos_profile = qos_profile or QoSProfile( depth=10, history=QoSHistoryPolicy.KEEP_LAST, durability=QoSDurabilityPolicy.SYSTEM_DEFAULT, @@ -56,7 +60,7 @@ def send_message(self, msg: BaseMessage, target: str) -> None: publisher = self.publishers.get(target) if publisher is None: self.publishers[target] = self.node.create_publisher( - msg.msg_type, target, qos_profile=self.default_qos_profile + msg.msg_type, target, qos_profile=self.qos_profile ) self.publishers[target].publish(msg.content) @@ -78,7 +82,7 @@ def receive_message(self, source: str) -> BaseMessage: import_message_from_str(msg_type), self.node, source, - qos_profile=self.default_qos_profile, + qos_profile=self.qos_profile, ) if status: From 269fb1cbf85a111bdb55c7f955f5a4b76d7e37d6 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 10 Jan 2025 10:19:46 +0100 Subject: [PATCH 07/14] feat: add destroy_publisher method --- src/rai/rai/communication/ros_connector.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/rai/rai/communication/ros_connector.py b/src/rai/rai/communication/ros_connector.py index 3354f26e..a8386174 100644 --- a/src/rai/rai/communication/ros_connector.py +++ b/src/rai/rai/communication/ros_connector.py @@ -107,6 +107,14 @@ def send_and_wait(self, target: str) -> BaseMessage: f"{self.__class__.__name__} does not suport sending messages" ) + def destroy_publisher(self, target: str): + publisher = self.publishers.get(target) + if publisher is not None: + publisher.destroy() + self.publishers.pop(target) + else: + raise ValueError(f"Publisher for {target} not found") + def cleanup(self): self.executor.shutdown() self.executor_thread.join() From 4f76da4d0e54d6ced157dce1e822b2351f672516 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 10 Jan 2025 19:54:16 +0100 Subject: [PATCH 08/14] refactor(test_ros_connector): use topic names based on test name, move related ROS 2 Node code inside relevant test cases --- tests/communication/test_ros_connector.py | 118 ++++++++++++---------- 1 file changed, 67 insertions(+), 51 deletions(-) diff --git a/tests/communication/test_ros_connector.py b/tests/communication/test_ros_connector.py index d1262f8a..64c1a8c6 100644 --- a/tests/communication/test_ros_connector.py +++ b/tests/communication/test_ros_connector.py @@ -28,49 +28,36 @@ from rai.communication.ros_connector import ROS2Connector -class ReceiverNode(Node): - def __init__(self): - super().__init__("receiver") - self.received_images: List[Image] = [] - self.received_strings: List[String] = [] - self.group = ReentrantCallbackGroup() - self.string_sub = self.create_subscription( # type: ignore - String, "/text", self.on_message, 10, callback_group=self.group - ) - self.image_sub = self.create_subscription( # type: ignore - Image, "/image", self.on_message, 10, callback_group=self.group - ) - - def on_message(self, msg: Optional[String | Image]): - if isinstance(msg, String): - self.received_strings.append(msg) - elif isinstance(msg, Image): - self.received_images.append(msg) - else: - raise ValueError(f"Unknown message type: {type(msg)}") - - -class SenderNode(Node): - def __init__(self): - super().__init__("sender") - self.text_pub = self.create_publisher( # type: ignore - String, - "/text", - 10, - ) - self.image_pub = self.create_publisher( # type: ignore - Image, - "/image", - 10, - ) - self.timer = self.create_timer(0.1, self.on_timer) # type: ignore - - def on_timer(self): - self.text_pub.publish(String(data="Hello, world!")) - self.image_pub.publish(Image()) - - def test_send_message(): + class ReceiverNode(Node): + def __init__(self): + super().__init__("receiver") + self.received_images: List[Image] = [] + self.received_strings: List[String] = [] + self.group = ReentrantCallbackGroup() + self.string_sub = self.create_subscription( # type: ignore + String, + "/test_send_message/text", + self.on_message, + 10, + callback_group=self.group, + ) + self.image_sub = self.create_subscription( # type: ignore + Image, + "/test_send_message/image", + self.on_message, + 10, + callback_group=self.group, + ) + + def on_message(self, msg: Optional[String | Image]): + if isinstance(msg, String): + self.received_strings.append(msg) + elif isinstance(msg, Image): + self.received_images.append(msg) + else: + raise ValueError(f"Unknown message type: {type(msg)}") + rclpy.init() receiver = ReceiverNode() receiver_executor = MultiThreadedExecutor() @@ -79,8 +66,10 @@ def test_send_message(): receiver_thread.start() connector = ROS2Connector() - connector.send_message(BaseMessage(String(data="Hello, world!")), "/text") - connector.send_message(BaseMessage(Image()), "/image") + connector.send_message( + BaseMessage(String(data="Hello, world!")), "/test_send_message/text" + ) + connector.send_message(BaseMessage(Image()), "/test_send_message/image") time.sleep(1.0) assert len(receiver.received_strings) == 1 @@ -91,6 +80,25 @@ def test_send_message(): def test_receive_message(): + class SenderNode(Node): + def __init__(self): + super().__init__("sender") + self.text_pub = self.create_publisher( # type: ignore + String, + "/test_receive_message/text", + 10, + ) + self.image_pub = self.create_publisher( # type: ignore + Image, + "/test_receive_message/image", + 10, + ) + self.timer = self.create_timer(0.1, self.on_timer) # type: ignore + + def on_timer(self): + self.text_pub.publish(String(data="Hello, world!")) + self.image_pub.publish(Image()) + rclpy.init() sender = SenderNode() sender_executor = MultiThreadedExecutor() @@ -99,11 +107,11 @@ def test_receive_message(): sender_thread.start() connector = ROS2Connector() - message = connector.receive_message("/text") + message = connector.receive_message("/test_receive_message/text") assert isinstance(message, BaseMessage) assert isinstance(message.content, String) - message = connector.receive_message("/image") + message = connector.receive_message("/test_receive_message/image") assert isinstance(message, BaseMessage) assert isinstance(message.content, Image) @@ -129,12 +137,18 @@ def test_send_and_wait(): def test_connector_publisher_reuse(): rclpy.init() connector = ROS2Connector() - connector.send_message(BaseMessage(String(data="Test")), "/text") - connector.send_message(BaseMessage(Image()), "/image") + connector.send_message( + BaseMessage(String(data="Test")), "/test_connector_publisher_reuse/text" + ) + connector.send_message( + BaseMessage(Image()), "/test_connector_publisher_reuse/image" + ) assert len(connector.publishers) == 2 - connector.send_message(BaseMessage(String(data="Test")), "/text") + connector.send_message( + BaseMessage(String(data="Test")), "/test_connector_publisher_reuse/text" + ) assert len(connector.publishers) == 2 connector.cleanup() rclpy.shutdown() @@ -144,8 +158,10 @@ def test_connector_cleanup(): rclpy.init() connector = ROS2Connector() - connector.send_message(BaseMessage(String(data="Test")), "/text") - connector.send_message(BaseMessage(Image()), "/image") + connector.send_message( + BaseMessage(String(data="Test")), "/test_connector_cleanup/text" + ) + connector.send_message(BaseMessage(Image()), "/test_connector_cleanup/image") initial_publisher_count = len(connector.publishers) assert initial_publisher_count == 2 From d7f8c606259db3251ffff2a57b38afd1bbf25ded Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 10 Jan 2025 19:55:16 +0100 Subject: [PATCH 09/14] refactor(test_transport): use function name based topics, move ROS 2 Node class into test code --- tests/messages/test_transport.py | 63 +++++++++++++++++--------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/tests/messages/test_transport.py b/tests/messages/test_transport.py index 93a31e5e..89dc2ecb 100644 --- a/tests/messages/test_transport.py +++ b/tests/messages/test_transport.py @@ -57,38 +57,43 @@ def get_qos_profiles() -> List[str]: raise ValueError(f"Invalid ROS_DISTRO: {ros_distro}") -class TestPublisher(Node): - def __init__(self, qos_profile: QoSProfile): - super().__init__("test_publisher_" + str(uuid.uuid4()).replace("-", "")) - - self.image_publisher = self.create_publisher(Image, "image", qos_profile) - self.text_publisher = self.create_publisher(String, "text", qos_profile) - - self.image_timer = self.create_timer(0.1, self.image_callback) - self.text_timer = self.create_timer(0.1, self.text_callback) - - def image_callback(self): - msg = Image() - msg.header.stamp = self.get_clock().now().to_msg() - msg.height = 540 - msg.width = 960 - msg.encoding = "bgr8" - msg.is_bigendian = 0 - msg.step = msg.width * 3 - msg.data = np.random.randint(0, 255, size=msg.height * msg.width * 3).tobytes() - self.image_publisher.publish(msg) - - def text_callback(self): - msg = String() - msg.data = "Hello, world!" - self.text_publisher.publish(msg) - - @pytest.mark.parametrize( "qos_profile", get_qos_profiles(), ) def test_transport(qos_profile: str): + class TestPublisher(Node): + def __init__(self, qos_profile: QoSProfile): + super().__init__("test_publisher_" + str(uuid.uuid4()).replace("-", "")) + + self.image_publisher = self.create_publisher( + Image, "/test_transport/image", qos_profile + ) + self.text_publisher = self.create_publisher( + String, "/test_transport/text", qos_profile + ) + + self.image_timer = self.create_timer(0.1, self.image_callback) + self.text_timer = self.create_timer(0.1, self.text_callback) + + def image_callback(self): + msg = Image() + msg.header.stamp = self.get_clock().now().to_msg() + msg.height = 540 + msg.width = 960 + msg.encoding = "bgr8" + msg.is_bigendian = 0 + msg.step = msg.width * 3 + msg.data = np.random.randint( + 0, 255, size=msg.height * msg.width * 3 + ).tobytes() + self.image_publisher.publish(msg) + + def text_callback(self): + msg = String() + msg.data = "Hello, world!" + self.text_publisher.publish(msg) + if not rclpy.ok(): rclpy.init() publisher = TestPublisher(QoSPresetProfiles.get_from_short_key(qos_profile)) @@ -103,10 +108,10 @@ def test_transport(qos_profile: str): thread2 = threading.Thread(target=rai_base_node.spin) thread2.start() - topics = ["/image", "/text"] + topics = ["/test_transport/image", "/test_transport/text"] try: for topic in topics: - output = rai_base_node.get_raw_message_from_topic(topic, timeout_sec=5.0) + output = rai_base_node.get_raw_message_from_topic(topic, timeout_sec=5) assert not isinstance(output, str), "No message received" finally: executor.shutdown() From 503ac649012abdca84c62bb70a22025bc3c9d3b3 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Mon, 13 Jan 2025 19:02:20 +0100 Subject: [PATCH 10/14] refactor: HRI and RRI connectors --- src/rai/rai/communication/__init__.py | 19 +- src/rai/rai/communication/base_connector.py | 57 ----- src/rai/rai/communication/hri_connector.py | 110 +++++++++ src/rai/rai/communication/http_connector.py | 145 ------------ src/rai/rai/communication/ros_connector.py | 121 ---------- src/rai/rai/communication/ros_connectors.py | 250 ++++++++++++++++++++ src/rai/rai/communication/rri_connector.py | 62 +++++ 7 files changed, 435 insertions(+), 329 deletions(-) delete mode 100644 src/rai/rai/communication/base_connector.py create mode 100644 src/rai/rai/communication/hri_connector.py delete mode 100644 src/rai/rai/communication/http_connector.py delete mode 100644 src/rai/rai/communication/ros_connector.py create mode 100644 src/rai/rai/communication/ros_connectors.py create mode 100644 src/rai/rai/communication/rri_connector.py diff --git a/src/rai/rai/communication/__init__.py b/src/rai/rai/communication/__init__.py index f22b8744..63b045ed 100644 --- a/src/rai/rai/communication/__init__.py +++ b/src/rai/rai/communication/__init__.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base_connector import BaseConnector, BaseMessage -from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice +from .ros_connectors import ROS2HRIConnector, ROS2RRIConnector __all__ = [ - "BaseMessage", - "BaseConnector", - "StreamingAudioInputDevice", - "SoundDeviceError", + "ROS2HRIConnector", + "ROS2RRIConnector", ] + +# from .base_connector import BaseConnector, BaseMessage +# from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice + +# __all__ = [ +# "BaseMessage", +# "BaseConnector", +# "StreamingAudioInputDevice", +# "SoundDeviceError", +# ] diff --git a/src/rai/rai/communication/base_connector.py b/src/rai/rai/communication/base_connector.py deleted file mode 100644 index 996de151..00000000 --- a/src/rai/rai/communication/base_connector.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any, Callable -from uuid import uuid4 - - -class BaseMessage(ABC): - def __init__(self, content: Any): - self.content = content - - def __repr__(self): - return f"{self.__class__.__name__}({self.content=})" - - @property - def msg_type(self) -> Any: - return type(self.content) - - -class BaseConnector(ABC): - - def _generate_handle(self) -> str: - return str(uuid4()) - - @abstractmethod - def send_message(self, msg: BaseMessage, target: str) -> None: - pass - - @abstractmethod - def receive_message(self, source: str) -> BaseMessage: - pass - - @abstractmethod - def send_and_wait(self, target: str) -> BaseMessage: - pass - - @abstractmethod - def start_action( - self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None - ) -> str: - pass - - @abstractmethod - def terminate_action(self, action_handle: str): - pass diff --git a/src/rai/rai/communication/hri_connector.py b/src/rai/rai/communication/hri_connector.py new file mode 100644 index 00000000..d28fef75 --- /dev/null +++ b/src/rai/rai/communication/hri_connector.py @@ -0,0 +1,110 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Annotated, List, Literal, Optional, Sequence + +from langchain_core.messages import AIMessage +from langchain_core.messages import BaseMessage as LangchainBaseMessage +from langchain_core.messages import HumanMessage + +from rai.messages import AiMultimodalMessage, HumanMultimodalMessage +from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage + + +class HRIMessage: + type: Literal["ai", "human"] + text: str + images: Optional[Annotated[List[str], "base64 encoded png images"]] + audios: Optional[Annotated[List[str], "base64 encoded wav audio"]] + + def __repr__(self): + return f"HRIMessage(type={self.type}, text={self.text}, images={self.images}, audios={self.audios})" + + def __init__( + self, + type: Literal["ai", "human"], + text: str, + images: Optional[List[str]] = None, + audios: Optional[List[str]] = None, + ): + self.type = type + self.text = text + self.images = images + self.audios = audios + + def to_langchain(self) -> LangchainBaseMessage: + match self.type: + case "human": + if self.images is None and self.audios is None: + return HumanMessage(content=self.text) + return HumanMultimodalMessage( + content=self.text, images=self.images, audios=self.audios + ) + case "ai": + if self.images is None and self.audios is None: + return AIMessage(content=self.text) + return AiMultimodalMessage( + content=self.text, images=self.images, audios=self.audios + ) + case _: + raise ValueError( + f"Invalid message type: {self.type} for {self.__class__.__name__}" + ) + + @classmethod + def from_langchain( + cls, + message: LangchainBaseMessage | RAIMultimodalMessage, + ) -> "HRIMessage": + if isinstance(message, RAIMultimodalMessage): + text = str(message.content["text"]) + images = message.images + audios = message.audios + else: + text = str(message.content) + images = None + audios = None + if message.type not in ["ai", "human"]: + raise ValueError(f"Invalid message type: {message.type} for {cls.__name__}") + return cls( + type=message.type, # type: ignore + text=text, + images=images, + audios=audios, + ) + + +class HRIConnector(ABC): + """ + Base class for Human-Robot Interaction (HRI) connectors. + Used for sending and receiving messages between human and robot from various sources. + """ + + configured_targets: Sequence[str] + configured_sources: Sequence[str] + + def build_message( + self, + message: LangchainBaseMessage | RAIMultimodalMessage, + ) -> HRIMessage: + return HRIMessage.from_langchain(message) + + @abstractmethod + def send_message(self, message: LangchainBaseMessage | RAIMultimodalMessage): + pass + + @abstractmethod + def receive_message(self) -> LangchainBaseMessage | RAIMultimodalMessage: + pass diff --git a/src/rai/rai/communication/http_connector.py b/src/rai/rai/communication/http_connector.py deleted file mode 100644 index 8fbda52a..00000000 --- a/src/rai/rai/communication/http_connector.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import atexit -import json -import threading -from datetime import datetime -from http.server import HTTPServer, SimpleHTTPRequestHandler -from typing import Any, Callable, Dict, List - -from rai.communication.base_connector import BaseConnector, BaseMessage - - -class MessageHandler(SimpleHTTPRequestHandler): - """Handler for HTTP requests serving a simple message viewer.""" - - def do_GET(self): - """Serve either the main HTML page or message data.""" - if self.path == "/": - self._serve_html() - elif self.path == "/messages": - self._serve_messages() - - def _serve_html(self): - """Serve the main HTML interface.""" - self.send_response(200) - self.send_header("Content-type", "text/html") - self.end_headers() - self.wfile.write(self.server.get_html_content().encode()) - - def _serve_messages(self): - """Serve message data as JSON.""" - self.send_response(200) - self.send_header("Content-type", "application/json") - self.end_headers() - self.wfile.write(json.dumps(self.server.messages).encode()) - - -class SimpleHTTPServer(HTTPServer): - """HTTP server that maintains a list of messages.""" - - def __init__(self, server_address, RequestHandlerClass): - super().__init__(server_address, RequestHandlerClass) - self.messages: List[Dict[str, Any]] = [] - - def add_message(self, message: str): - """Add a new message with timestamp.""" - self.messages.append( - {"timestamp": datetime.now().isoformat(), "content": message} - ) - - def get_html_content(self) -> str: - """Return the HTML content for the web interface.""" - return """ - - - - Simple Message Viewer - - - -

Messages

-
- - - - """ - - -class HTTPConnector(BaseConnector): - """Connector that displays messages via a web interface.""" - - def __init__(self, host: str = "localhost", port: int = 8000): - self.host = host - self.port = port - - self.server = SimpleHTTPServer((self.host, self.port), MessageHandler) - self.server_thread = threading.Thread(target=self.server.serve_forever) - self.server_thread.daemon = True - self.server_thread.start() - - atexit.register(self.cleanup) - print(f"Server started at http://{self.host}:{self.port}") - - def send_message(self, msg: BaseMessage, target: str) -> None: - """Add message to the web interface.""" - self.server.add_message(str(msg.content)) - - def receive_message(self, source: str) -> BaseMessage: - raise NotImplementedError( - f"{self.__class__.__name__} does not support receiving messages" - ) - - def start_action( - self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None # type: ignore - ) -> str: - raise NotImplementedError( - f"{self.__class__.__name__} does not suport starting actions" - ) - - def terminate_action(self, action_handle: str): - raise NotImplementedError( - f"{self.__class__.__name__} does not suport terminating actions" - ) - - def send_and_wait(self, target: str) -> BaseMessage: - raise NotImplementedError( - f"{self.__class__.__name__} does not suport sending messages" - ) - - def cleanup(self): - """Clean up server resources.""" - self.server.shutdown() - self.server.server_close() diff --git a/src/rai/rai/communication/ros_connector.py b/src/rai/rai/communication/ros_connector.py deleted file mode 100644 index a8386174..00000000 --- a/src/rai/rai/communication/ros_connector.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import atexit -import threading -from typing import Callable, Dict, Optional - -import rclpy -from rclpy.executors import MultiThreadedExecutor -from rclpy.node import Node -from rclpy.publisher import Publisher -from rclpy.qos import ( - QoSDurabilityPolicy, - QoSHistoryPolicy, - QoSProfile, - QoSReliabilityPolicy, -) - -from rai.communication.base_connector import BaseConnector, BaseMessage -from rai.tools.ros.utils import import_message_from_str -from rai.tools.utils import wait_for_message - - -class ROS2Connector(BaseConnector): - def __init__( - self, - node_name: str = "rai_ros2_connector", - qos_profile: Optional[QoSProfile] = None, - ): - if not rclpy.ok(): - rclpy.init() - - self.node = Node(node_name=node_name) - self.publishers: Dict[str, Publisher] = {} - self.qos_profile = qos_profile or QoSProfile( - depth=10, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.SYSTEM_DEFAULT, - reliability=QoSReliabilityPolicy.RELIABLE, - ) - - self.executor = MultiThreadedExecutor() - self.executor.add_node(self.node) - self.executor_thread = threading.Thread(target=self.executor.spin) - self.executor_thread.start() - atexit.register(self.cleanup) - - def send_message(self, msg: BaseMessage, target: str) -> None: - publisher = self.publishers.get(target) - if publisher is None: - self.publishers[target] = self.node.create_publisher( - msg.msg_type, target, qos_profile=self.qos_profile - ) - self.publishers[target].publish(msg.content) - - def _validate_and_get_msg_type(self, topic: str): - """ - Validate that the topic exists and return the message type. - """ - topic_names_and_types = self.node.get_topic_names_and_types() - topic_names = [topic for topic, _ in topic_names_and_types] - if topic not in topic_names: - raise ValueError( - f"Topic '{topic}' not found. Available topics: {topic_names}" - ) - return topic_names_and_types[topic_names.index(topic)][1][0] - - def receive_message(self, source: str) -> BaseMessage: - msg_type = self._validate_and_get_msg_type(source) - status, msg = wait_for_message( - import_message_from_str(msg_type), - self.node, - source, - qos_profile=self.qos_profile, - ) - - if status: - return BaseMessage(content=msg) - else: - raise ValueError(f"No message found for {source}") - - def start_action( - self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None - ) -> str: - raise NotImplementedError( - f"{self.__class__.__name__} does not suport starting actions" - ) - - def terminate_action(self, action_handle: str): - raise NotImplementedError( - f"{self.__class__.__name__} does not suport terminating actions" - ) - - def send_and_wait(self, target: str) -> BaseMessage: - raise NotImplementedError( - f"{self.__class__.__name__} does not suport sending messages" - ) - - def destroy_publisher(self, target: str): - publisher = self.publishers.get(target) - if publisher is not None: - publisher.destroy() - self.publishers.pop(target) - else: - raise ValueError(f"Publisher for {target} not found") - - def cleanup(self): - self.executor.shutdown() - self.executor_thread.join() - self.node.destroy_node() diff --git a/src/rai/rai/communication/ros_connectors.py b/src/rai/rai/communication/ros_connectors.py new file mode 100644 index 00000000..e39f1afc --- /dev/null +++ b/src/rai/rai/communication/ros_connectors.py @@ -0,0 +1,250 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +from typing import Any, Callable, Dict, List, Sequence +from uuid import uuid4 + +import rclpy +import rclpy.callback_groups +import rclpy.executors +import rclpy.node +import rclpy.qos +import rclpy.subscription +import rclpy.task +import rosidl_runtime_py.set_message +import rosidl_runtime_py.utilities +from langchain_core.messages import BaseMessage as LangchainBaseMessage +from rclpy.action.client import ActionClient as ROS2ActionClient +from rclpy.node import Node +from rclpy.publisher import Publisher +from rclpy.qos import qos_profile_default +from rosidl_runtime_py.utilities import get_namespaced_type +from std_msgs.msg import String + +from rai.communication.hri_connector import HRIConnector, HRIMessage +from rai.communication.rri_connector import ROS2RRIMessage, RRIConnector, RRIMessage +from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage +from rai.tools.ros.utils import import_message_from_str, wait_for_message + + +class ServiceCaller: + node: Node + + def _service_call( + self, message: ROS2RRIMessage, target: str, timeout_sec: float = 1.0 + ) -> ROS2RRIMessage: + response = self.__service_call( + target, message.ros_message_type, message.payload + ) + if isinstance(response, str): + return ROS2RRIMessage( + payload=response, ros_message_type="str", python_message_class=str + ) + else: + return ROS2RRIMessage( + payload=response, + ros_message_type=message.ros_message_type, + python_message_class=type(response), + ) + + def _build_request(self, service_type: str, request_args: Dict[str, Any]) -> Any: + srv_module, _, srv_name = service_type.split("/") + srv_class = getattr(importlib.import_module(f"{srv_module}.srv"), srv_name) + request = srv_class.Request() + rosidl_runtime_py.set_message.set_message_fields(request, request_args) + return request + + def __service_call( # TODO: refactor into single _ method + self, service_name: str, service_type: str, request_args: Dict[str, Any] + ) -> str | object: + if not service_name.startswith("/"): + service_name = f"/{service_name}" + + try: + request = self._build_request(service_type, request_args) + except Exception as e: + return f"Failed to build service request: {e}" + namespaced_type = get_namespaced_type(service_type) + client = self.node.create_client( + rosidl_runtime_py.import_message.import_message_from_namespaced_type( + namespaced_type + ), + service_name, + ) + + if not client.wait_for_service(timeout_sec=1.0): + return f"Service '{service_name}' is not available" + + future = client.call_async(request) + rclpy.spin_until_future_complete(self.node, future) + + if future.result() is not None: + return future.result() + else: + return f"Service call to '{service_name}' failed" + + +class TopicSubscriber: + node: Node + + def _receive_message( + self, source: str, timeout_sec: float = 1.0 + ) -> str | object: # TODO: ROS 2 msg type + publishers_info = self.node.get_publishers_info_by_topic(topic_name=source) + if len(publishers_info) == 0: + return f"No publisher found for topic {source}." + + msg_type_str = publishers_info[0].topic_type + if len({publisher_info.topic_type for publisher_info in publishers_info}) > 1: + logging.warning( + f"Multiple publishers on topic {source} with different message types. Will use the first one." + ) + + msg_type = import_message_from_str(msg_type_str) + status, ros2_msg = wait_for_message( + msg_type=msg_type, + node=self.node, + topic=source, + time_to_wait=int(timeout_sec), + ) + if status: + return ros2_msg + else: + return ( + f"Message could not be received from {source}. Try increasing timeout." + ) + + +class TopicPublisher: + node: Node + publishers: Dict[str, Publisher] = {} + + def _publish( + self, + message: ROS2RRIMessage, + topic: str, + ): + msg_python_type = import_message_from_str(message.ros_message_type) + if topic not in self.publishers: + # TODO: use qos matching when available + self.publishers[topic] = self.node.create_publisher( + msg_python_type, topic, qos_profile_default + ) + ros2_msg = msg_python_type() + rosidl_runtime_py.set_message.set_message_fields(ros2_msg, message.payload) + self.publishers[topic].publish(ros2_msg) + + +class ActionClient: + node: Node + action_clients: Dict[str, ROS2ActionClient] = {} + + def _generate_handle(self) -> str: + return str(uuid4()) + + def _start_action( + self, + action: ROS2RRIMessage, + target: str, + on_feedback: Callable[[Any], None], + on_done: Callable[[Any], None], + timeout_sec: float = 1.0, + ) -> str: + raise NotImplementedError("Action client not implemented") + + def _terminate_action(self, action_handle: str) -> ROS2RRIMessage: + raise NotImplementedError("Action termination not implemented") + + +class ROS2RRIConnector( + RRIConnector, + TopicPublisher, + TopicSubscriber, + ServiceCaller, + ActionClient, +): + def __init__(self, node: Node): + self.node = node + + def send_message( + self, + message: ROS2RRIMessage, + target: str, + ): + self._publish(message, target) + + def receive_message(self, source: str, timeout_sec: float = 1.0) -> ROS2RRIMessage: + ros2_msg = self._receive_message(source, timeout_sec) + if isinstance(ros2_msg, str): + return ROS2RRIMessage( + payload=ros2_msg, ros_message_type="str", python_message_class=str + ) + else: + return ROS2RRIMessage( + payload=ros2_msg, + ros_message_type=type(ros2_msg).__name__, + python_message_class=type(ros2_msg), + ) + + def service_call( + self, message: RRIMessage, target: str, timeout_sec: float = 1.0 + ) -> RRIMessage: + if not isinstance(message, ROS2RRIMessage): + raise TypeError("Message must be of type ROS2RRIMessage") + return self._service_call(message, target, timeout_sec) + + def start_action( + self, + action: RRIMessage, + target: str, + on_feedback: Callable[[Any], None], + on_done: Callable[[Any], None], + timeout_sec: float = 1.0, + ) -> str: + if not isinstance(action, ROS2RRIMessage): + raise TypeError("Action must be of type ROS2RRIMessage") + return self._start_action(action, target, on_feedback, on_done, timeout_sec) + + def terminate_action(self, action_handle: str) -> ROS2RRIMessage: + return self._terminate_action(action_handle) + + +class ROS2HRIConnector(HRIConnector, TopicSubscriber, TopicPublisher): + def __init__(self, node: Node, sources: Sequence[str], targets: Sequence[str]): + self.node = node + self.sources = sources + self.targets = targets + + def send_message(self, message: LangchainBaseMessage | RAIMultimodalMessage): + for target in self.targets: + hri_message = HRIMessage.from_langchain(message) + ros2rri_message = ROS2RRIMessage( + payload=hri_message.text, # TODO: Only string topics + ros_message_type=type(hri_message).__name__, + python_message_class=type(hri_message), + ) + self._publish(ros2rri_message, target) + + def receive_message(self) -> LangchainBaseMessage | RAIMultimodalMessage: + messages: List[Dict[str, str]] = [] + for source in self.sources: + ros2_msg = self._receive_message(source) + if not isinstance(ros2_msg, String): + raise ValueError( + f"Received message from {source} is not a string. Only string topics are supported." + ) + messages.append(ros2_msg) + return HRIMessage(text=str(messages), type="human").to_langchain() diff --git a/src/rai/rai/communication/rri_connector.py b/src/rai/rai/communication/rri_connector.py new file mode 100644 index 00000000..84c74339 --- /dev/null +++ b/src/rai/rai/communication/rri_connector.py @@ -0,0 +1,62 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + +from pydantic import BaseModel, Field + + +class RRIMessage(BaseModel): + payload: Any = Field(description="The payload of the message") + + +class ROS2RRIMessage(RRIMessage): + ros_message_type: str = Field( + description="The string representation of the ROS message type (e.g. 'std_msgs/String')" + ) + python_message_class: Optional[type] = Field( + description="The Python class of the ROS message type", default=None + ) + + +class RRIConnector(ABC): + """ + Base class for Robot-Robot Interaction (RRI) connectors. + """ + + @abstractmethod + def send_message(self, message: Any, target: str): + pass + + @abstractmethod + def receive_message(self, source: str, timeout_sec: float = 1.0) -> RRIMessage: + pass + + @abstractmethod + def service_call( + self, message: RRIMessage, target: str, timeout_sec: float = 1.0 + ) -> RRIMessage: + pass + + @abstractmethod + def start_action( + self, + action: RRIMessage, + target: str, + on_feedback: Callable[[Any], None], + on_done: Callable[[Any], None], + timeout_sec: float = 1.0, + ) -> str: + pass From 4c238e5a0de02e2822c9849e34d49a264bbc9f0a Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Mon, 13 Jan 2025 19:19:55 +0100 Subject: [PATCH 11/14] fix: string messages --- src/rai/rai/communication/ros_connectors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rai/rai/communication/ros_connectors.py b/src/rai/rai/communication/ros_connectors.py index e39f1afc..ca6645b9 100644 --- a/src/rai/rai/communication/ros_connectors.py +++ b/src/rai/rai/communication/ros_connectors.py @@ -232,9 +232,9 @@ def send_message(self, message: LangchainBaseMessage | RAIMultimodalMessage): for target in self.targets: hri_message = HRIMessage.from_langchain(message) ros2rri_message = ROS2RRIMessage( - payload=hri_message.text, # TODO: Only string topics - ros_message_type=type(hri_message).__name__, - python_message_class=type(hri_message), + payload={"data": hri_message.text}, # TODO: Only string topics + ros_message_type="std_msgs/msg/String", + python_message_class=String, ) self._publish(ros2rri_message, target) From 17dde2b8839acb40a66289dd10a911295bb5840a Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 14 Jan 2025 16:04:48 +0100 Subject: [PATCH 12/14] chore: move ros2 apis to communication/ros folder --- src/rai/rai/{ros2_apis.py => communication/ros2/api.py} | 0 src/rai/rai/node.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/rai/rai/{ros2_apis.py => communication/ros2/api.py} (100%) diff --git a/src/rai/rai/ros2_apis.py b/src/rai/rai/communication/ros2/api.py similarity index 100% rename from src/rai/rai/ros2_apis.py rename to src/rai/rai/communication/ros2/api.py diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index fcc2d3ae..d1d76f8a 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -34,7 +34,7 @@ from rai.agents.state_based import Report, State, create_state_based_agent from rai.messages import HumanMultimodalMessage -from rai.ros2_apis import Ros2ActionsAPI, Ros2TopicsAPI +from rai.communication.ros2.api import Ros2ActionsAPI, Ros2TopicsAPI from rai.tools.ros.native import Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_base64 from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks From 66c1494471e43ab16e130fff2208fe308b6f6461 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 14 Jan 2025 21:43:43 +0100 Subject: [PATCH 13/14] feat: implement ROS2TopicAPI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bartłomiej Boczek --- src/rai/rai/communication/ros2/api.py | 508 ++++++++++---------------- src/rai/rai/tools/ros/utils.py | 4 +- 2 files changed, 196 insertions(+), 316 deletions(-) diff --git a/src/rai/rai/communication/ros2/api.py b/src/rai/rai/communication/ros2/api.py index 359885f0..61ed6759 100644 --- a/src/rai/rai/communication/ros2/api.py +++ b/src/rai/rai/communication/ros2/api.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -import time -from typing import Any, Dict, List, Optional, Tuple, Type, Union +import logging +from typing import Any, Dict, List, Optional, Tuple, Type import rclpy import rclpy.callback_groups @@ -25,8 +24,7 @@ import rclpy.task import rosidl_runtime_py.set_message import rosidl_runtime_py.utilities -from action_msgs.msg import GoalStatus -from rclpy.action.client import ActionClient +from rclpy.publisher import Publisher from rclpy.qos import ( DurabilityPolicy, HistoryPolicy, @@ -36,342 +34,224 @@ ) from rclpy.topic_endpoint_info import TopicEndpointInfo -from rai.tools.ros.utils import import_message_from_str -from rai.utils.ros import NodeDiscovery -from rai.utils.ros_async import get_future_result +from rai.tools.ros.utils import import_message_from_str, wait_for_message + + +def adapt_requests_to_offers(publisher_info: List[TopicEndpointInfo]) -> QoSProfile: + if not publisher_info: + return QoSProfile(depth=1) + + num_endpoints = len(publisher_info) + reliability_reliable_count = 0 + durability_transient_local_count = 0 + + for endpoint in publisher_info: + profile = endpoint.qos_profile + if profile.reliability == ReliabilityPolicy.RELIABLE: + reliability_reliable_count += 1 + if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: + durability_transient_local_count += 1 + + request_qos = QoSProfile( + history=HistoryPolicy.KEEP_LAST, + depth=1, + liveliness=LivelinessPolicy.AUTOMATIC, + ) + + # Set reliability based on publisher offers + if reliability_reliable_count == num_endpoints: + request_qos.reliability = ReliabilityPolicy.RELIABLE + else: + if reliability_reliable_count > 0: + logging.warning( + "Some, but not all, publishers are offering RELIABLE reliability. " + "Falling back to BEST_EFFORT as it will connect to all publishers. " + "Some messages from Reliable publishers could be dropped." + ) + request_qos.reliability = ReliabilityPolicy.BEST_EFFORT + + # Set durability based on publisher offers + if durability_transient_local_count == num_endpoints: + request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL + else: + if durability_transient_local_count > 0: + logging.warning( + "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " + "Falling back to VOLATILE as it will connect to all publishers. " + "Previously-published latched messages will not be retrieved." + ) + request_qos.durability = DurabilityPolicy.VOLATILE + return request_qos -def ros2_build_msg(msg_type: str, msg_args: Dict[str, Any]) -> Tuple[object, Type]: - """ - Import message and create it. Return both ready message and message class. - msgs args can have two formats: - { "goal" : {arg 1 : xyz, ... } or {arg 1 : xyz, ... } +def build_ros2_msg(msg_type: str, msg_args: Dict[str, Any]) -> object: + """Build a ROS2 message instance from type string and content dictionary.""" + msg_cls = import_message_from_str(msg_type) + msg = msg_cls() + rosidl_runtime_py.set_message.set_message_fields(msg, msg_args) + return msg + + +class ROS2TopicAPI: + """Handles ROS2 topic operations including publishing and subscribing to messages. + + This class provides a high-level interface for ROS2 topic operations with automatic + QoS profile matching and proper resource management. + + Attributes: + node: The ROS2 node instance + logger: Logger instance for this class + _publishers: Dictionary mapping topic names to their publisher instances """ - msg_cls: Type = rosidl_runtime_py.utilities.get_interface(msg_type) - msg = msg_cls.Goal() + def __init__(self, node: rclpy.node.Node) -> None: + """Initialize the ROS2 topic API. - if "goal" in msg_args: - msg_args = msg_args["goal"] - rosidl_runtime_py.set_message.set_message_fields(msg, msg_args) - return msg, msg_cls + Args: + node: ROS2 node instance to use for communication + """ + self._node = node + self._logger = node.get_logger() + self._publishers: Dict[str, Publisher] = {} + + def list_topics(self) -> List[Tuple[str, List[str]]]: + """Get list of available topics and their types. + Returns: + List of tuples containing (topic_name, list_of_types) + """ + return self._node.get_topic_names_and_types() -class Ros2TopicsAPI: - def __init__( + def publish( self, - node: rclpy.node.Node, - callback_group: rclpy.callback_groups.CallbackGroup, - ros_discovery_info: NodeDiscovery, + topic: str, + msg_content: Dict[str, Any], + msg_type: str, + *, # Force keyword arguments + auto_qos_matching: bool = True, + qos_profile: Optional[QoSProfile] = None, ) -> None: - self.node = node - self.callback_group = callback_group - self.last_subscription_msgs_buffer = dict() - self.qos_profile_cache: Dict[str, QoSProfile] = dict() - - self.ros_discovery_info = ros_discovery_info + """Publish a message to a ROS2 topic. + + Args: + topic: Name of the topic to publish to + msg_content: Dictionary containing the message content + msg_type: ROS2 message type as string (e.g. 'std_msgs/msg/String') + auto_qos_matching: Whether to automatically match QoS with subscribers + qos_profile: Optional custom QoS profile to use + + Raises: + ValueError: If neither auto_qos_matching is True nor qos_profile is provided + """ + qos_profile = self._resolve_qos_profile( + topic, auto_qos_matching, qos_profile, for_publisher=True + ) - def get_logger(self): - return self.node.get_logger() + msg = build_ros2_msg(msg_type, msg_content) + publisher = self._get_or_create_publisher(topic, type(msg), qos_profile) + publisher.publish(msg) - def adapt_requests_to_offers( - self, publisher_info: List[TopicEndpointInfo] - ) -> QoSProfile: - if not publisher_info: - return QoSProfile(depth=1) - - num_endpoints = len(publisher_info) - reliability_reliable_count = 0 - durability_transient_local_count = 0 - - for endpoint in publisher_info: - profile = endpoint.qos_profile - if profile.reliability == ReliabilityPolicy.RELIABLE: - reliability_reliable_count += 1 - if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: - durability_transient_local_count += 1 - - request_qos = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - liveliness=LivelinessPolicy.AUTOMATIC, - ) + def receive( + self, + topic: str, + msg_type: str, + *, # Force keyword arguments + timeout_sec: float = 1.0, + auto_qos_matching: bool = True, + qos_profile: Optional[QoSProfile] = None, + ) -> Any: + """Receive a single message from a ROS2 topic. - # Set reliability based on publisher offers - if reliability_reliable_count == num_endpoints: - request_qos.reliability = ReliabilityPolicy.RELIABLE - else: - if reliability_reliable_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering RELIABLE reliability. " - "Falling back to BEST_EFFORT as it will connect to all publishers. " - "Some messages from Reliable publishers could be dropped." - ) - request_qos.reliability = ReliabilityPolicy.BEST_EFFORT - - # Set durability based on publisher offers - if durability_transient_local_count == num_endpoints: - request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL - else: - if durability_transient_local_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " - "Falling back to VOLATILE as it will connect to all publishers. " - "Previously-published latched messages will not be retrieved." - ) - request_qos.durability = DurabilityPolicy.VOLATILE - - return request_qos - - def create_subscription_by_topic_name(self, topic): - if self.has_subscription(topic): - self.get_logger().warning( - f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" - ) - return + Args: + topic: Name of the topic to receive from + msg_type: ROS2 message type as string + timeout_sec: How long to wait for a message + auto_qos_matching: Whether to automatically match QoS with publishers + qos_profile: Optional custom QoS profile to use - msg_type = self.get_msg_type(topic) + Returns: + The received message - if topic not in self.qos_profile_cache: - self.get_logger().debug(f"Getting qos profile for topic: {topic}") - qos_profile = self.adapt_requests_to_offers( - self.node.get_publishers_info_by_topic(topic) - ) - self.qos_profile_cache[topic] = qos_profile - else: - self.get_logger().debug(f"Using cached qos profile for topic: {topic}") - qos_profile = self.qos_profile_cache[topic] + Raises: + ValueError: If no publisher exists or no message is received within timeout + """ + self._verify_publisher_exists(topic) - topic_callback = functools.partial( - self.generic_state_subscriber_callback, topic + qos_profile = self._resolve_qos_profile( + topic, auto_qos_matching, qos_profile, for_publisher=False ) - self.node.create_subscription( - msg_type, + msg_cls = self._get_message_class(msg_type) + success, msg = wait_for_message( + msg_cls, + self._node, topic, - callback=topic_callback, - callback_group=self.callback_group, qos_profile=qos_profile, + time_to_wait=int(timeout_sec), ) - def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: - """Sometimes node fails to do full discovery, therefore we need to retry""" - for _ in range(n_tries): - if topic in self.ros_discovery_info.topics_and_types: - msg_type = self.ros_discovery_info.topics_and_types[topic] - return import_message_from_str(msg_type) - else: - # Wait for next discovery cycle - self.get_logger().info(f"Waiting for topic: {topic}") - if self.ros_discovery_info: - time.sleep(self.ros_discovery_info.period_sec) - else: - time.sleep(1.0) - raise KeyError(f"Topic {topic} not found") - - def set_ros_discovery_info(self, new_ros_discovery_info: NodeDiscovery): - self.ros_discovery_info = new_ros_discovery_info - - def get_raw_message_from_topic( - self, topic: str, timeout_sec: int = 5, topic_wait_sec: int = 2 - ) -> Any: - self.get_logger().debug(f"Getting msg from topic: {topic}") - - ts = time.perf_counter() - - for _ in range(topic_wait_sec * 10): - if topic not in self.ros_discovery_info.topics_and_types: - time.sleep(0.1) - continue - else: - break - - if topic not in self.ros_discovery_info.topics_and_types: - raise KeyError( - f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" + if not success: + raise ValueError( + f"No message received from topic: {topic} within {timeout_sec} seconds" + ) + return msg + + def _get_or_create_publisher( + self, topic: str, msg_cls: Type[Any], qos_profile: QoSProfile + ) -> Publisher: + """Get existing publisher or create a new one if it doesn't exist.""" + if topic not in self._publishers: + self._publishers[topic] = self._node.create_publisher( # type: ignore + msg_cls, topic, qos_profile=qos_profile ) + return self._publishers[topic] - if topic in self.last_subscription_msgs_buffer: - self.get_logger().info("Returning cached message") - return self.last_subscription_msgs_buffer[topic] - else: - self.create_subscription_by_topic_name(topic) - try: - msg = self.last_subscription_msgs_buffer.get(topic, None) - while msg is None and time.perf_counter() - ts < timeout_sec: - msg = self.last_subscription_msgs_buffer.get(topic, None) - self.get_logger().info("Waiting for message...") - time.sleep(0.1) - - success = msg is not None - - if success: - self.get_logger().debug( - f"Received message of type {type(msg)} from topic {topic}" - ) - return msg - else: - error = f"No message received in {timeout_sec} seconds from topic {topic}" - self.get_logger().error(error) - return error - finally: - self.destroy_subscription_by_topic_name(topic) - - def generic_state_subscriber_callback(self, topic_name: str, msg: Any): - self.get_logger().debug( - f"Received message of type {type(msg)} from topic {topic_name}" - ) - self.last_subscription_msgs_buffer[topic_name] = msg - - def has_subscription(self, topic: str) -> bool: - for sub in self.node._subscriptions: - if sub.topic == topic: - return True - return False - - def destroy_subscription_by_topic_name(self, topic: str): - self.last_subscription_msgs_buffer.clear() - for sub in self.node._subscriptions: - if sub.topic == topic: - self.node.destroy_subscription(sub) - - -class Ros2ActionsAPI: - def __init__(self, node: rclpy.node.Node): - self.node = node - - self.goal_handle = None - self.result_future = None - self.feedback = None - self.status: Optional[int] = None - self.client: Optional[ActionClient] = None - self.action_feedback: Optional[Any] = None - - def get_logger(self): - return self.node.get_logger() - - def run_action( - self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] - ): - if not self.is_task_complete(): - raise AssertionError( - "Another ros2 action is currently running and parallel actions are not supported. Please wait until the previous action is complete before starting a new one. You can also cancel the current one." + def _resolve_qos_profile( + self, + topic: str, + auto_qos_matching: bool, + qos_profile: Optional[QoSProfile], + for_publisher: bool, + ) -> QoSProfile: + """Resolve which QoS profile to use based on settings and existing endpoints.""" + if auto_qos_matching and qos_profile is not None: + self._logger.warning( # type: ignore + "Auto QoS matching is enabled, but qos_profile is provided. " + "Using provided qos_profile." ) + return qos_profile - if action_name[0] != "/": - action_name = "/" + action_name - self.get_logger().info(f"Action name corrected to: {action_name}") - - try: - goal_msg, msg_cls = ros2_build_msg(action_type, action_goal_args) - except Exception as e: - return f"Failed to build message: {e}" - - self.client = ActionClient(self.node, msg_cls, action_name) - self.msg_cls = msg_cls - - retries = 0 - while not self.client.wait_for_server(timeout_sec=1.0): - retries += 1 - if retries > 5: - raise Exception( - f"Action server '{action_name}' is not available. Make sure `action_name` is correct..." - ) - self.get_logger().info( - f"'{action_name}' action server not available, waiting..." + if auto_qos_matching: + endpoint_info = ( + self._node.get_subscriptions_info_by_topic(topic) + if for_publisher + else self._node.get_publishers_info_by_topic(topic) ) + return adapt_requests_to_offers(endpoint_info) - self.get_logger().info(f"Sending action message: {goal_msg}") + if qos_profile is not None: + return qos_profile - send_goal_future = self.client.send_goal_async( - goal_msg, self._feedback_callback + raise ValueError( + "Either auto_qos_matching must be True or qos_profile must be provided" ) - self.get_logger().info("Action goal sent!") - - self.goal_handle = get_future_result(send_goal_future) - - if not self.goal_handle: - raise Exception(f"Action '{action_name}' not sent to server") - - if not self.goal_handle.accepted: - raise Exception(f"Action '{action_name}' not accepted by server") - - self.result_future = self.goal_handle.get_result_async() - self.get_logger().info("Action sent!") - return f"{action_name} started successfully with args: {action_goal_args}" - - def get_task_result(self) -> str: - if not self.is_task_complete(): - return "Task is not complete yet" - - def parse_status(status: int) -> str: - return { - GoalStatus.STATUS_SUCCEEDED: "succeeded", - GoalStatus.STATUS_ABORTED: "aborted", - GoalStatus.STATUS_CANCELED: "canceled", - GoalStatus.STATUS_ACCEPTED: "accepted", - GoalStatus.STATUS_CANCELING: "canceling", - GoalStatus.STATUS_EXECUTING: "executing", - GoalStatus.STATUS_UNKNOWN: "unknown", - }.get(status, "unknown") - - result = self.result_future.result() - - self.destroy_client() - if result.status == GoalStatus.STATUS_SUCCEEDED: - msg = f"Result succeeded: {result.result}" - self.get_logger().info(msg) - return msg - else: - str_status = parse_status(result.status) - error_code_str = self.parse_error_code(result.result.error_code) - msg = f"Result {str_status}, because of: error_code={result.result.error_code}({error_code_str}), error_msg={result.result.error_msg}" - self.get_logger().info(msg) - return msg - - def parse_error_code(self, code: int) -> str: - code_to_name = { - v: k for k, v in vars(self.msg_cls.Result).items() if isinstance(v, int) - } - return code_to_name.get(code, "UNKNOWN") - - def _feedback_callback(self, msg): - self.get_logger().info(f"Received ros2 action feedback: {msg}") - self.action_feedback = msg - - def is_task_complete(self): - if not self.result_future: - # task was cancelled or completed - return True - - result = get_future_result(self.result_future, timeout_sec=0.10) - if result is not None: - self.status = result.status - if self.status != GoalStatus.STATUS_SUCCEEDED: - self.get_logger().debug( - f"Task with failed with status code: {self.status}" - ) - return True - else: - self.get_logger().info("There is no result") - # Timed out, still processing, not complete yet - return False - - self.get_logger().info("Task succeeded!") - return True - - def cancel_task(self) -> Union[str, bool]: - self.get_logger().info("Canceling current task.") - try: - if self.result_future and self.goal_handle: - future = self.goal_handle.cancel_goal_async() - result = get_future_result(future, timeout_sec=1.0) - return "Failed to cancel result" if result is None else True - return True - finally: - self.destroy_client() - - def destroy_client(self): - if self.client: - self.client.destroy() + + @staticmethod + def _get_message_class(msg_type: str) -> Type[Any]: + """Convert message type string to actual message class.""" + return import_message_from_str(msg_type) + + def _verify_publisher_exists(self, topic: str) -> None: + """Verify that at least one publisher exists for the given topic. + + Raises: + ValueError: If no publisher exists for the topic + """ + if not self._node.get_publishers_info_by_topic(topic): + raise ValueError(f"No publisher found for topic: {topic}") + + def __del__(self) -> None: + """Cleanup publishers when object is destroyed.""" + for publisher in self._publishers.values(): + publisher.destroy() diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index 7756a6c0..f18a8071 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -14,7 +14,7 @@ import base64 -from typing import Optional, Type, Union, cast +from typing import Optional, Tuple, Type, Union, cast import cv2 import numpy as np @@ -111,7 +111,7 @@ def wait_for_message( *, qos_profile: Union[QoSProfile, int] = 1, time_to_wait=-1, -): +) -> Tuple[bool, Optional[object]]: """ Wait for the next incoming message. From 581aef338f7f7d7bb9f2eae091c2f59427065a41 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Tue, 14 Jan 2025 22:42:59 +0100 Subject: [PATCH 14/14] feat: implement ROS2ServiceAPI --- src/rai/rai/communication/ros2/api.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/rai/rai/communication/ros2/api.py b/src/rai/rai/communication/ros2/api.py index 61ed6759..dc74e896 100644 --- a/src/rai/rai/communication/ros2/api.py +++ b/src/rai/rai/communication/ros2/api.py @@ -93,6 +93,15 @@ def build_ros2_msg(msg_type: str, msg_args: Dict[str, Any]) -> object: return msg +def build_ros2_service_request( + service_type: str, service_request_args: Dict[str, Any] +) -> Tuple[object, Type[Any]]: + msg_cls = import_message_from_str(service_type) + msg = msg_cls.Request() + rosidl_runtime_py.set_message.set_message_fields(msg, service_request_args) + return msg, msg_cls # type: ignore + + class ROS2TopicAPI: """Handles ROS2 topic operations including publishing and subscribing to messages. @@ -255,3 +264,13 @@ def __del__(self) -> None: """Cleanup publishers when object is destroyed.""" for publisher in self._publishers.values(): publisher.destroy() + + +class ROS2ServiceAPI: + def __init__(self, node: rclpy.node.Node) -> None: + self.node = node + + def call_service(self, service_name: str, service_type: str, request: Any) -> Any: + srv_msg, srv_cls = build_ros2_service_request(service_type, request) + service_client = self.node.create_client(srv_cls, service_name) # type: ignore + return service_client.call(srv_msg)