Skip to content

Commit

Permalink
refactor: manipulation toolset (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek authored Oct 18, 2024
1 parent b1c2fba commit 0333ea2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
15 changes: 7 additions & 8 deletions src/rai/rai/tools/ros/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ class MoveToPointTool(BaseTool):
node: Node
client: Client

manipulator_frame: str = Field(
default="panda_link0", description="Manipulator frame"
)
manipulator_frame: str = Field(..., description="Manipulator frame")
min_z: float = Field(default=0.135, description="Minimum z coordinate [m]")
calibration_x: float = Field(default=0.071, description="Calibration x [m]")
calibration_y: float = Field(default=-0.025, description="Calibration y [m]")
Expand All @@ -74,13 +72,14 @@ class MoveToPointTool(BaseTool):

args_schema: Type[MoveToPointToolInput] = MoveToPointToolInput

def __init__(self, node: Node):
def __init__(self, node: Node, **kwargs):
super().__init__(
node=node,
client=node.create_client(
ManipulatorMoveTo,
"/manipulator_move_to",
),
**kwargs,
)

def _run(
Expand All @@ -106,7 +105,7 @@ def _run(

pose_stamped.pose.position.z = np.max(
[pose_stamped.pose.position.z, self.min_z]
) # avoid hitting the table
)

request = ManipulatorMoveTo.Request()
request.target_pose = pose_stamped
Expand Down Expand Up @@ -144,13 +143,13 @@ class GetObjectPositionsToolInput(BaseModel):
class GetObjectPositionsTool(BaseTool):
name: str = "get_object_positions"
description: str = (
"Retrieve the positions of all objects of a specified type within the manipulator's frame of reference. "
"Retrieve the positions of all objects of a specified type in the target frame. "
"This tool provides accurate positional data but does not distinguish between different colors of the same object type. "
"While position detection is reliable, please note that object classification may occasionally be inaccurate."
)

target_frame: str # frame of the manipulator
source_frame: str # frame of the camera
target_frame: str
source_frame: str
camera_topic: str # rgb camera topic
depth_topic: str
camera_info_topic: str # rgb camera info topic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Sequence, Type

import cv2
import numpy as np
Expand Down Expand Up @@ -177,7 +177,9 @@ def _run(
return "", {"segmentations": ret}


def depth_to_point_cloud(depth_image, fx, fy, cx, cy):
def depth_to_point_cloud(
depth_image: np.ndarray, fx: float, fy: float, cx: float, cy: float
):
height, width = depth_image.shape

# Create grid of pixel coordinates
Expand Down Expand Up @@ -233,13 +235,15 @@ def _process_mask(
self,
mask_msg: sensor_msgs.msg.Image,
depth_msg: sensor_msgs.msg.Image,
intrinsic,
intrinsic: Sequence[float],
depth_to_meters_ratio: float,
):
mask = convert_ros_img_to_ndarray(mask_msg)
binary_mask = np.where(mask == 255, 1, 0)
depth = convert_ros_img_to_ndarray(depth_msg)
masked_depth_image = np.zeros_like(depth, dtype=np.float32)
masked_depth_image[binary_mask == 1] = depth[binary_mask == 1]
masked_depth_image = masked_depth_image * depth_to_meters_ratio

pcd = depth_to_point_cloud(
masked_depth_image, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3]
Expand All @@ -248,6 +252,7 @@ def _process_mask(
# TODO: Filter out outliers
points = pcd

# https://github.com/ycheng517/tabletop-handybot/blob/6d401e577e41ea86529d091b406fbfc936f37a8d/tabletop_handybot/tabletop_handybot/tabletop_handybot_node.py#L413-L424
grasp_z = points[:, 2].max()
near_grasp_z_points = points[points[:, 2] > grasp_z - 0.008]
xy_points = near_grasp_z_points[:, :2]
Expand All @@ -265,8 +270,6 @@ def _process_mask(

# Calculate full 3D centroid for OBJECT
centroid = np.mean(points, axis=0)
# TODO : change offset to be dependant on the height of the object
centroid[2] += 0.1 # Added a small offset to prevent gripper collision
return centroid, gripper_rotation

def _run(
Expand Down Expand Up @@ -315,6 +318,13 @@ def _run(
assert resolved is not None
rets = []
for mask_msg in resolved.masks:
rets.append(self._process_mask(mask_msg, depth_msg, intrinsic))
rets.append(
self._process_mask(
mask_msg,
depth_msg,
intrinsic,
depth_to_meters_ratio=conversion_ratio,
)
)

return rets

0 comments on commit 0333ea2

Please sign in to comment.