diff --git a/CMakeLists.txt b/CMakeLists.txt index e23689e215..779a36205f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ catkin_package(CATKIN_DEPENDS message_runtime) ############# install(PROGRAMS - nodes/mask_rcnn_node + scripts/mask_rcnn_node.py DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) diff --git a/README.md b/README.md index 450ad16f51..fdc7daf6c7 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Most of core algorithm code was based on [Mask R-CNN implementation by Matterpor ## Training This repository doesn't contain code for training Mask R-CNN network model. -If you want to train the model on youer own class definition or dataset, try it on [the upstream reposity](https://github.com/matterport/Mask_RCNN) and give the result weight to `model_path` parameter. +If you want to train the model on your own class definition or dataset, try it on [the upstream reposity](https://github.com/matterport/Mask_RCNN) and give the result weight to `model_path` parameter. ## Requirements @@ -70,13 +70,12 @@ If you want to train the model on youer own class definition or dataset, try it There is a simple example launch file using [RGB-D SLAM Dataset](https://vision.in.tum.de/data/datasets/rgbd-dataset/download). ~~~bash -$ cd mask_rcnn_ros/examples -$ ./download_example_bag.sh -$ roslaunch example.launch +$ ./scripts/download_freiburg3_rgbd_example_bag.sh +$ roslaunch mask_rcnn_ros freiburg3_rgbd_example.launch ~~~ Then RViz window will appear and show result like following: ![example1](doc/mask_r-cnn_1.png) -![example2](doc/mask_r-cnn_2.png) \ No newline at end of file +![example2](doc/mask_r-cnn_2.png) diff --git a/bags/.placeholder b/bags/.placeholder new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/example.launch b/launch/freiburg3_rgbd_example.launch similarity index 66% rename from examples/example.launch rename to launch/freiburg3_rgbd_example.launch index 1d90222c1a..e11316c2ba 100644 --- a/examples/example.launch +++ b/launch/freiburg3_rgbd_example.launch @@ -1,11 +1,11 @@ - + + args="-l $(find mask_rcnn_ros)/bags/rgbd_dataset_freiburg3_long_office_household.bag" /> - + diff --git a/requirements.txt b/requirements.txt index 147e460de5..67bd209465 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ h5py==2.7.0 Keras==2.1.2 numpy==1.13.3 rviz==1.12.13 +opencv-python==3.4.0.12 scikit-image==0.13.0 scikit-learn==0.19.1 scipy==0.19.1 diff --git a/examples/example.rviz b/rviz/mask_rcnn_ros.rviz similarity index 100% rename from examples/example.rviz rename to rviz/mask_rcnn_ros.rviz diff --git a/examples/download_example_bag.sh b/scripts/download_freiburg3_rgbd_example_bag.sh similarity index 61% rename from examples/download_example_bag.sh rename to scripts/download_freiburg3_rgbd_example_bag.sh index bf8814ff1c..9d4296b6c1 100755 --- a/examples/download_example_bag.sh +++ b/scripts/download_freiburg3_rgbd_example_bag.sh @@ -1,2 +1,2 @@ !/bin/sh -wget https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_long_office_household.bag +wget https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_long_office_household.bag -P bags diff --git a/nodes/mask_rcnn_node b/scripts/mask_rcnn_node.py similarity index 91% rename from nodes/mask_rcnn_node rename to scripts/mask_rcnn_node.py index 0bbc136ee8..baacc9edae 100755 --- a/nodes/mask_rcnn_node +++ b/scripts/mask_rcnn_node.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import os import threading -from Queue import Queue import numpy as np import cv2 @@ -9,7 +8,6 @@ import rospy from sensor_msgs.msg import Image from sensor_msgs.msg import RegionOfInterest -from std_msgs.msg import UInt8MultiArray from mask_rcnn_ros import coco from mask_rcnn_ros import utils @@ -48,6 +46,7 @@ class InferenceConfig(coco.CocoConfig): GPU_COUNT = 1 IMAGES_PER_GPU = 1 + class MaskRCNNNode(object): def __init__(self): self._cv_bridge = CvBridge() @@ -80,8 +79,8 @@ def __init__(self): def run(self): self._result_pub = rospy.Publisher('~result', Result, queue_size=1) vis_pub = rospy.Publisher('~visualization', Image, queue_size=1) - sub = rospy.Subscriber('~input', Image, - self._image_callback, queue_size=1) + rospy.Subscriber('~input', Image, + self._image_callback, queue_size=1) rate = rospy.Rate(self._publish_rate) while not rospy.is_shutdown(): @@ -104,9 +103,7 @@ def run(self): # Visualize results if self._visualization: - vis_image = self._visualize(result, np_image) - cv_result = np.zeros(shape=vis_image.shape, dtype=np.uint8) - cv2.convertScaleAbs(vis_image, cv_result) + cv_result = self._visualize_cv(result, np_image) image_msg = self._cv_bridge.cv2_to_imgmsg(cv_result, 'bgr8') vis_pub.publish(image_msg) @@ -162,17 +159,28 @@ def _visualize(self, result, image): result = result.reshape((int(h), int(w), 3)) return result + def _visualize_cv(self, result, image): + + image = visualize.display_instances_cv(image, result['rois'], result['masks'], + result['class_ids'], CLASS_NAMES, + result['scores'], + class_colors=self._class_colors) + + return image + def _image_callback(self, msg): rospy.logdebug("Get an image") if self._msg_lock.acquire(False): self._last_msg = msg self._msg_lock.release() + def main(): rospy.init_node('mask_rcnn') node = MaskRCNNNode() node.run() + if __name__ == '__main__': main() diff --git a/src/mask_rcnn_ros/visualize.py b/src/mask_rcnn_ros/visualize.py index 16cca66a13..6acaea942a 100644 --- a/src/mask_rcnn_ros/visualize.py +++ b/src/mask_rcnn_ros/visualize.py @@ -16,6 +16,7 @@ import matplotlib.patches as patches import matplotlib.lines as lines from matplotlib.patches import Polygon +import cv2 import IPython.display import utils @@ -150,6 +151,61 @@ def display_instances(image, boxes, masks, class_ids, class_names, #plt.show() +def display_instances_cv(image, boxes, masks, class_ids, class_names, + scores=None, class_colors=None, alpha=0.7): + """ + boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates. + masks: [height, width, num_instances] + class_ids: [num_instances] + class_names: list of class names of the dataset + scores: (optional) confidence scores for each box + class_colors: a list mapping class ids to their colors + alpha: the amount of transparency of the mask overlay + """ + # Number of instances + n = boxes.shape[0] + if n: + assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0] + + # Generate random colors + if class_colors is None: + colors = random_colors(n) + + for i in range(n): + class_id = class_ids[i] + if class_colors is None: + color = colors[i] + else: + color = class_colors[class_id] + + # Transform class colors to BGR and rescale [0-255] for OpenCv + bgr_color = tuple(c*255 for c in color[::-1]) + + # Draw bounding boxes + if not np.any(boxes[i]): + # Skip this instance. Has no bbox. Likely lost in image cropping. + continue + y1, x1, y2, x2 = boxes[i] + cv2.rectangle(image, (x1, y1), (x2, y2), color=bgr_color, thickness=2) + + # Draw transparent mask + overlay = image.copy() + mask = masks[:, :, i] + __, thresh = cv2.threshold(mask, 0.5, 1, cv2.THRESH_BINARY) + _, contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + cv2.drawContours(image, contours, -1, color=bgr_color, thickness=cv2.FILLED) + cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) + + # Draw text label + score = scores[i] if scores is not None else None + label = class_names[class_id] + caption = "{} {:.3f}".format(label, score) if score else label + cv2.putText(image, caption, (x1, y1 + 12), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5, + color=(255, 255, 255)) + + return image + + def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10): """ anchors: [n, (y1, x1, y2, x2)] list of anchors in image coordinates.