Skip to content

Commit

Permalink
Merge pull request #395 from hdefazio/dev/hannah_test
Browse files Browse the repository at this point in the history
Fix image resizing in object detector node, Add drawing joints to demo UI
  • Loading branch information
hdefazio authored Apr 18, 2024
2 parents 5028435 + 44ed9da commit 146cddf
Show file tree
Hide file tree
Showing 11 changed files with 396 additions and 146 deletions.
Empty file.
47 changes: 47 additions & 0 deletions angel_system/object_detection/yolov8_detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

from ultralytics import YOLO as YOLOv8


def predict_hands(hand_model: YOLOv8, img0: np.array, device: str, imgsz: int) -> tuple:
"""Predict hands using a YOLOv8 hand model and update the labels to be
hand(left) and hand(right)
"""
width, height = img0.shape[:2]
hands_preds = hand_model.predict(
source=img0,
conf=0.1,
imgsz=imgsz,
device=device,
verbose=False
)[0] # list of length=num images

hand_centers = [center.xywh.tolist()[0][0] for center in hands_preds.boxes][:2]
hands_label = []

# Update the hand label to left and right specific labels
if len(hand_centers) == 2:
if hand_centers[0] > hand_centers[1]:
hands_label.append("hand (right)")
hands_label.append("hand (left)")
elif hand_centers[0] <= hand_centers[1]:
hands_label.append("hand (left)")
hands_label.append("hand (right)")
elif len(hand_centers) == 1:
if hand_centers[0] > width//2:
hands_label.append("hand (right)")
elif hand_centers[0] <= width//2:
hands_label.append("hand (left)")

boxes, labels, confs = [], [], []

for bbox, hand_cid in zip(hands_preds.boxes, hands_label):
xyxy_hand = bbox.xyxy.tolist()[0]

conf = bbox.conf.item()

boxes.append(xyxy_hand)
labels.append(hand_cid)
confs.append(conf)

return boxes, labels, confs
2 changes: 1 addition & 1 deletion python-tpl/yolov7
Submodule yolov7 updated 1 files
+0 −39 yolov7/detect_ptg.py
1 change: 1 addition & 0 deletions ros/angel_msgs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ set( message_files
msg/HeadsetPoseData.msg
msg/InterpretedAudioUserIntent.msg
msg/InterpretedAudioUserEmotion.msg
msg/JointKeypoints.msg
msg/ObjectDetection2d.msg
msg/ObjectDetection2dSet.msg
msg/ObjectDetection3dSet.msg
Expand Down
6 changes: 5 additions & 1 deletion ros/angel_msgs/msg/JointKeypoints.msg
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
# Message that contains a snapshot of the patient joint poses
#

# Header frame_id should indicate the source these detections were predicted
# over.
std_msgs/Header header
frame_id

# Timestamp of the source image these predictions pertain to.
builtin_interfaces/Time source_stamp

# List of joints
float64[] keypoints
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from rclpy.node import Node, ParameterDescriptor, Parameter
from sensor_msgs.msg import Image

from yolov7.detect_ptg import load_model, predict_image, predict_hands
from yolov7.detect_ptg import load_model, predict_image
from angel_system.object_detection.yolov8_detect import predict_hands
from yolov7.models.experimental import attempt_load
import yolov7.models.yolo
from yolov7.utils.torch_utils import TracedModel
from ultralytics import YOLO
from ultralytics import YOLO as YOLOv8

from angel_system.utils.event import WaitAndClearEvent
from angel_system.utils.simple_timer import SimpleTimer
Expand All @@ -26,7 +27,7 @@
BRIDGE = CvBridge()


class ObjectHandDetector(Node):
class ObjectAndHandDetector(Node):
"""
ROS node that runs the yolov7 object detector model and outputs
`ObjectDetection2dSet` messages.
Expand All @@ -44,7 +45,7 @@ def __init__(self):
# Required parameter (no defaults)
("image_topic",),
("det_topic",),
("net_checkpoint",),
("object_net_checkpoint",),
("hand_net_checkpoint",),
##################################
# Defaulted parameters
Expand All @@ -59,16 +60,14 @@ def __init__(self):
# If we should enable additional logging to the info level
# about when we receive and process data.
("enable_time_trace_logging", False),
("image_resize", False)
],
)
self._image_topic = param_values["image_topic"]
self._det_topic = param_values["det_topic"]
self._model_ckpt_fp = Path(param_values["net_checkpoint"])

self._object_model_ckpt_fp = Path(param_values["object_net_checkpoint"])
self._hand_model_chpt_fp = Path(param_values["hand_net_checkpoint"])

self._ensure_image_resize = param_values["image_resize"]

self._inference_img_size = param_values["inference_img_size"]
self._det_conf_thresh = param_values["det_conf_threshold"]
self._iou_thr = param_values["iou_threshold"]
Expand All @@ -78,18 +77,18 @@ def __init__(self):

self._enable_trace_logging = param_values["enable_time_trace_logging"]

# Model
self.model: Union[yolov7.models.yolo.Model, TracedModel]
if not self._model_ckpt_fp.is_file():
# Object Model
self.object_model: Union[yolov7.models.yolo.Model, TracedModel]
if not self._object_model_ckpt_fp.is_file():
raise ValueError(
f"Model checkpoint file did not exist: {self._model_ckpt_fp}"
f"Model checkpoint file did not exist: {self._object_model_ckpt_fp}"
)
(self.device, self.model, self.stride, self.imgsz) = load_model(
str(self._cuda_device_id), self._model_ckpt_fp, self._inference_img_size
(self.device, self.object_model, self.stride, self.imgsz) = load_model(
str(self._cuda_device_id), self._object_model_ckpt_fp, self._inference_img_size
)
log.info(
f"Loaded model with classes:\n"
+ "\n".join(f'\t- "{n}"' for n in self.model.names)
+ "\n".join(f'\t- "{n}"' for n in self.object_model.names)
)

# Single slot for latest image message to process detection over.
Expand All @@ -111,16 +110,16 @@ def __init__(self):
callback_group=MutuallyExclusiveCallbackGroup(),
)

self.hand_model = YOLO(self._hand_model_chpt_fp)
self.hand_model = YOLOv8(self._hand_model_chpt_fp)

if not self._no_trace:
self.model = TracedModel(self.model, self.device, self._inference_img_size)
self.object_model = TracedModel(self.object_model, self.device, self._inference_img_size)

self.half = half = (
self.device.type != "cpu"
) # half precision only supported on CUDA
if half:
self.model.half() # to FP16
self.object_model.half() # to FP16

self._rate_tracker = RateTracker()
log.info("Detector initialized")
Expand Down Expand Up @@ -173,10 +172,10 @@ def rt_loop(self):
log.info("Runtime loop starting")
enable_trace_logging = self._enable_trace_logging

if "background" in self.model.names:
label_vector = self.model.names[1:] # remove background label
if "background" in self.object_model.names:
label_vector = self.object_model.names[1:] # remove background label
else:
label_vector = self.model.names
label_vector = self.object_model.names

label_vector.append("hand (left)")
label_vector.append("hand (right)")
Expand All @@ -199,51 +198,48 @@ def rt_loop(self):

if enable_trace_logging:
log.info(f"[rt-loop] Processing image TS={image.header.stamp}")

# Convert ROS img msg to CV2 image
img0 = BRIDGE.imgmsg_to_cv2(image, desired_encoding="bgr8")


print(f"img0: {img0.shape}")
print(f"img0 type: {type(img0)}")

# width, height = self._inference_img_size
if self._ensure_image_resize:
img0 = cv2.resize(img0, dsize=(1280, 720), interpolation=cv2.INTER_CUBIC)

print(f"img0: {img0.shape}")

msg = ObjectDetection2dSet()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = image.header.frame_id
msg.source_stamp = image.header.stamp
msg.label_vec[:] = label_vector

print(f"model names: {self.model.names}")
print(f"object model names: {self.object_model.names}")

n_dets = 0

dflt_conf_vec = np.zeros(n_classes, dtype=np.float64)

hand_boxes, hand_labels, hand_confs = predict_hands(hand_model=self.hand_model,
img0=img0,
device=self.device)
# Detect hands
hand_boxes, hand_labels, hand_confs = predict_hands(
hand_model=self.hand_model,
img0=img0,
device=self.device,
imgsz=self._inference_img_size
)

hand_classids = [hand_cid_label_dict[label] for label in hand_labels]


# Detect objects
objcet_boxes, object_confs, objects_classids = predict_image(
img0,
self.device,
self.model,
self.stride,
self.imgsz,
self.half,
False,
self._det_conf_thresh,
self._iou_thr,
None,
self._agnostic_nms,
)
img0,
self.device,
self.object_model,
self.stride,
self.imgsz,
self.half,
False,
self._det_conf_thresh,
self._iou_thr,
None,
self._agnostic_nms,
)

objcet_boxes.extend(hand_boxes)
object_confs.extend(hand_confs)
Expand Down Expand Up @@ -285,7 +281,7 @@ def destroy_node(self):
# - 1 known subscriber which has their own group
# - 1 for default group
# - 1 for publishers
main = make_default_main(ObjectHandDetector, multithreaded_executor=3)
main = make_default_main(ObjectAndHandDetector, multithreaded_executor=3)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion ros/angel_system_nodes/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"object_detector_with_descriptors_v2 = angel_system_nodes.object_detection.object_detector_with_descriptors_v2:main",
"object_detection_yolo_v7 = angel_system_nodes.object_detection.yolov7_object_detector:main",
"object_detection_filter = angel_system_nodes.object_detection.object_detection_filter:main",
"object_hand_detector = angel_system_nodes.object_detection.object_hand_detection:main",
"object_and_hand_detector = angel_system_nodes.object_detection.object_and_hand_detection:main",
# Activity Classification
"activity_classifier_tcn = angel_system_nodes.activity_classification.activity_classifier_tcn:main",
"activity_detector = angel_system_nodes.activity_classification.activity_detector:main",
Expand Down
Loading

0 comments on commit 146cddf

Please sign in to comment.