Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Add benchmarking utilities #403

Draft
wants to merge 1 commit into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/rai/rai/tools/ros/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from rclpy.node import Node
from tf2_geometry_msgs import do_transform_pose

from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.utils import TF2TransformFetcher
from rai_interfaces.srv import ManipulatorMoveTo

Expand Down Expand Up @@ -156,9 +157,13 @@ class GetObjectPositionsTool(BaseTool):
node: Node
get_grabbing_point_tool: GetGrabbingPointTool

def __init__(self, node: Node, **kwargs):
def __init__(self, connector: ROS2ARIConnector, node: Node, **kwargs):
super(GetObjectPositionsTool, self).__init__(
node=node, get_grabbing_point_tool=GetGrabbingPointTool(node=node), **kwargs
node=node,
get_grabbing_point_tool=GetGrabbingPointTool(
connector=connector, node=node
),
**kwargs,
)

args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput
Expand Down
Empty file added src/rai_benchmarks/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions src/rai_benchmarks/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import rclpy
from manager import RaiBenchmarkManager
from rclpy.executors import MultiThreadedExecutor
from scenarios.longest_object import LongestObject
from scenarios.move_to_the_left import MoveToTheLeft
from scenarios.place_on_top import PlaceOnTop
from scenarios.replace_types import ReplaceTypes


def main(args=None):
rclpy.init(args=args)

manager = RaiBenchmarkManager(
[PlaceOnTop, LongestObject, MoveToTheLeft, ReplaceTypes], list(range(4))
)

executor = MultiThreadedExecutor(2)
executor.add_node(manager)
executor.spin()

manager.destroy_node()
rclpy.shutdown()


if __name__ == "__main__":
main()
182 changes: 182 additions & 0 deletions src/rai_benchmarks/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import random
import time
from threading import Thread

from gazebo_msgs.srv import DeleteEntity, SpawnEntity
from geometry_msgs.msg import Point, Quaternion
from langchain_core.messages import HumanMessage
from rclpy.node import Node
from rclpy.task import Future
from scenarios.scenario_base import ScenarioBase
from tf2_ros import Buffer, TransformListener

from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.node import RaiBaseNode
from rai.tools.ros2.topics import GetROS2ImageTool
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.tools.ros.native import Ros2GetTopicsNamesAndTypesTool
from rai.utils.model_initialization import get_llm_model
from rai_interfaces.srv import ManipulatorMoveTo


class ScenarioManager(Node):
"""
A class responsible for playing the scenarios
"""

def __init__(self, scenario_types, seeds=[]):
"""
Initializes the ScenarioManager

Args:
scenario_types: A list of scenario classes to play
seeds: A list of seeds to use for each scenario
"""
super().__init__("scenario_manager")
self.scenario_types = scenario_types
self.seeds = seeds

self.spawn_client = self.create_client(SpawnEntity, "/spawn_entity")
self.delete_client = self.create_client(DeleteEntity, "/delete_entity")
self.manipulator_client = self.create_client(
ManipulatorMoveTo, "/manipulator_move_to"
)
self.tf2_buffer = Buffer()
self.tf2_listener = TransformListener(self.tf2_buffer, self)

while not self.spawn_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info("service not available, waiting again...")
while not self.delete_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info("service not available, waiting again...")

self.timer = self.create_timer(1.0, self.timer_callback)
self.scenario: ScenarioBase = None
self.current_scenario = 0
self.agent_thread: Thread = None
self.manipulator_ready = False
self.scores = []

def _init_scenario(self):
self.scenario = self.scenario_types[self.current_scenario](
self.spawn_client, self.delete_client, self.manipulator_client, self
)
self.manipulator_ready = False
request = ManipulatorMoveTo.Request()
request.target_pose.pose.orientation = Quaternion(
x=0.923880, y=-0.382683, z=0.0, w=0.0
)
request.target_pose.pose.position = Point(x=0.2, y=0.0, z=0.2)
if self.current_scenario < len(self.seeds):
random.seed(self.seeds[self.current_scenario])
else:
random.seed(42)

def callback(future: Future):
self.manipulator_ready = True
self.scenario.reset()

self.scenario.manipulator_client.call_async(request).add_done_callback(callback)

def _terminate_scenario(self):
self.get_logger().info(f"Scenario terminated with score {self.scores[-1]}")
self.scenario = None
self.tf2_buffer = Buffer()
self.tf2_listener = TransformListener(self.tf2_buffer, self)
if self.current_scenario == len(self.scenario_types) - 1:
self.get_logger().info(
f"All scenarios are completed, with scores: {self.scores}"
)
request = ManipulatorMoveTo.Request()
request.target_pose.pose.orientation = Quaternion(
x=0.923880, y=-0.382683, z=0.0, w=0.0
)
request.target_pose.pose.position = Point(x=0.2, y=0.0, z=0.2)

def callback(future: Future):
self.manipulator_ready = True
self.executor.shutdown()

self.manipulator_client.call_async(request).add_done_callback(callback)
self.timer.cancel()
return
self.current_scenario = (self.current_scenario + 1) % len(self.scenario_types)
self.manipulator_ready = False

def timer_callback(self):
if self.scenario is None:
self._init_scenario()

if not self.manipulator_ready:
return

progress, terminated = self.scenario.step()
if terminated and not (self.agent_thread and self.agent_thread.is_alive()):
self.scores.append(progress)
self._terminate_scenario()
return
if self.agent_thread and not self.agent_thread.is_alive():
self.get_logger().info(
"Agent failed to fulfill the task, terminating the scenario."
)
self.scores.append(progress)
self._terminate_scenario()


class RaiBenchmarkManager(ScenarioManager):
"""
A class responsible for playing the scenarios and running the conversational agent for each scenario
"""

def __init__(self, scenario_types, seeds=[]):
super().__init__(scenario_types, seeds)
self.agent = None

def _init_scenario(self):
super()._init_scenario()

self.rai_node = RaiBaseNode(node_name="manipulation_demo")
self.rai_node.declare_parameter("conversion_ratio", 1.0)

connector = ROS2ARIConnector()
tools = [
GetObjectPositionsTool(
connector=connector,
node=self.rai_node,
target_frame="panda_link0",
source_frame="RGBDCamera5",
camera_topic="/color_image5",
depth_topic="/depth_image5",
camera_info_topic="/color_camera_info5",
),
MoveToPointTool(node=self.rai_node, manipulator_frame="panda_link0"),
GetROS2ImageTool(node=self.rai_node, connector=connector),
Ros2GetTopicsNamesAndTypesTool(node=self.rai_node),
]

llm = get_llm_model(model_type="complex_model")

system_prompt = """
You are a robotic arm with interfaces to detect and manipulate objects.
Here are the coordinates information:
x - front to back (positive is forward)
y - left to right (positive is right)
z - up to down (positive is up)

Before starting the task, make sure to grab the camera image to understand the environment.
"""

self.agent = create_conversational_agent(
llm=llm,
tools=tools,
system_prompt=system_prompt,
)

def run_agent():
time.sleep(1)
self.agent.invoke(
{"messages": [HumanMessage(content=self.scenario.get_prompt())]}
)["messages"][-1].pretty_print()

self.agent_thread = Thread(target=run_agent)
self.agent_thread.start()
123 changes: 123 additions & 0 deletions src/rai_benchmarks/scenarios/longest_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import rclpy
from geometry_msgs.msg import Point, PoseStamped, Quaternion
from rclpy.client import Client
from rclpy.node import Node
from rclpy.task import Future
from scenarios.scenario_base import ScenarioBase

from rai_interfaces.srv import ManipulatorMoveTo


class LongestObject(ScenarioBase):
def __init__(
self,
spawn_client: Client,
delete_client: Client,
manipulator_client: Client,
node: Node,
):
super().__init__(spawn_client, delete_client, manipulator_client, node)

def get_prompt(self):
return "Put the longest object from the table into the toy box."

def reset(self):
super().reset()

prefabs = ["apple", "yellow_cube", "blue_cube", "carrot"]
self.spawn_entities_in_random_positions(prefabs, prefabs)

pose = PoseStamped()
pose.header.frame_id = "world"
pose.pose.position = Point(x=0.4, y=-0.5, z=0.1)
pose.pose.orientation = Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
pose_transformed = self.node.tf2_buffer.transform(
pose, "odom", timeout=rclpy.time.Duration(seconds=5.0)
).pose
self.spawn_entity("toy_box", "toy_box", pose_transformed)

def calculate_progress(self):
if len(self.entities) == 0:
return 0.0

carrot_position = self.pose_transformed(self.get_entity_pose("carrot")).position
toy_box_position = self.pose_transformed(
self.get_entity_pose("toy_box")
).position

def distance(a, b):
return ((a.x - b.x) ** 2 + (a.y - b.y) ** 2 + (a.z - b.z) ** 2) ** 0.5

return max(0.0, 1.0 - 3.0 * distance(carrot_position, toy_box_position))

def step(self):
if len(self.entities) == 0:
return 0.0, False

progress = self.calculate_progress()

return progress, progress >= 0.5


class LongestObjectAuto(LongestObject):
def __init__(
self,
spawn_client: Client,
delete_client: Client,
manipulator_client: Client,
node: Node,
):
self.manipulator_busy = False

super().__init__(spawn_client, delete_client, manipulator_client, node)

def reset(self):
super().reset()

self.manipulator_busy = False
self.manipulator_queue = []

def place_on_top(self, bot_object: str, top_object: str):
pose = self.get_entity_pose(top_object)
pose.position.z += 0.1

req = ManipulatorMoveTo.Request()
req.initial_gripper_state = True
req.target_pose.pose = self.pose_transformed(pose)
req.final_gripper_state = False
self.manipulator_queue.append(req)

pose = self.get_entity_pose(bot_object)
pose.position.z += 0.2
req = ManipulatorMoveTo.Request()
req.initial_gripper_state = False
req.target_pose.pose = self.pose_transformed(pose)
req.final_gripper_state = True
self.manipulator_queue.append(req)

def move_callback(self, future: Future):
result = future.result()
if result.success:
self.node.get_logger().debug("Move performed")
else:
self.node.get_logger().error("Failed to perform move")
self.manipulator_busy = False

def step(self):
if len(self.entities) == 0:
return 0.0, False

progress = self.calculate_progress()

if not self.manipulator_busy:
if len(self.manipulator_queue) == 0 and progress < 0.8:
self.place_on_top("toy_box", "carrot")

if len(self.manipulator_queue) > 0:
req = self.manipulator_queue.pop(0)
self.manipulator_busy = True
self.manipulator_client.call_async(req).add_done_callback(
self.move_callback
)

return progress, progress >= 0.8 and not self.manipulator_busy
Loading