From 0719a3e87b3fd1ae69790dcf6aa55a812e7040dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?=
Date: Fri, 27 Oct 2023 22:42:34 +0200
Subject: [PATCH] formatting

---
 yolov8_ros/yolov8_ros/detect_3d_node.py | 88 +++++++++++++++----------
 yolov8_ros/yolov8_ros/yolov8_node.py    |  9 +--
 2 files changed, 58 insertions(+), 39 deletions(-)

diff --git a/yolov8_ros/yolov8_ros/detect_3d_node.py b/yolov8_ros/yolov8_ros/detect_3d_node.py
index efa5e92..2dc180b 100644
--- a/yolov8_ros/yolov8_ros/detect_3d_node.py
+++ b/yolov8_ros/yolov8_ros/detect_3d_node.py
@@ -13,22 +13,27 @@
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
-from typing import List, Tuple
-import message_filters
 import numpy as np
+from typing import List, Tuple
+
 import rclpy
-from cv_bridge import CvBridge
-from geometry_msgs.msg import TransformStamped
 from rclpy.node import Node
 from rclpy.qos import qos_profile_sensor_data
 from sensor_msgs.msg import CameraInfo, Image
-from tf2_ros import TransformException
+from geometry_msgs.msg import TransformStamped
+
+import message_filters
+from cv_bridge import CvBridge
 from tf2_ros.buffer import Buffer
+from tf2_ros import TransformException
 from tf2_ros.transform_listener import TransformListener
 
-from yolov8_msgs.msg import (BoundingBox3D, Detection, DetectionArray,
-                             KeyPoint3D, KeyPoint3DArray)
+from yolov8_msgs.msg import Detection
+from yolov8_msgs.msg import DetectionArray
+from yolov8_msgs.msg import KeyPoint3D
+from yolov8_msgs.msg import KeyPoint3DArray
+from yolov8_msgs.msg import BoundingBox3D
 
 
 class Detect3DNode(Node):
@@ -69,11 +74,12 @@ def __init__(self) -> None:
             (self.depth_sub, self.depth_info_sub, self.detections_sub), 10, 0.5)
         self._synchronizer.registerCallback(self.on_detections)
 
-    def on_detections(self,
-                      depth_msg: Image,
-                      depth_info_msg: CameraInfo,
-                      detections_msg: DetectionArray,
-                      ) -> None:
+    def on_detections(
+        self,
+        depth_msg: Image,
+        depth_info_msg: CameraInfo,
+        detections_msg: DetectionArray,
+    ) -> None:
 
         new_detections_msg = DetectionArray()
         new_detections_msg.header = detections_msg.header
@@ -82,10 +88,10 @@ def on_detections(self,
         self._pub.publish(new_detections_msg)
 
     def process_detections(
-            self,
-            depth_msg: Image,
-            depth_info_msg: CameraInfo,
-            detections_msg: DetectionArray
+        self,
+        depth_msg: Image,
+        depth_info_msg: CameraInfo,
+        detections_msg: DetectionArray
     ) -> List[Detection]:
 
         # check if there are detections
@@ -122,30 +128,37 @@ def process_detections(
 
         return new_detections
 
-    def convert_bb_to_3d(self,
-                         depth_image: np.ndarray,
-                         depth_info: CameraInfo,
-                         detection: Detection
-                         ) -> BoundingBox3D:
+    def convert_bb_to_3d(
+        self,
+        depth_image: np.ndarray,
+        depth_info: CameraInfo,
+        detection: Detection
+    ) -> BoundingBox3D:
 
         # crop depth image by the 2d BB
         center_x = int(detection.bbox.center.position.x)
         center_y = int(detection.bbox.center.position.y)
         size_x = int(detection.bbox.size.x)
         size_y = int(detection.bbox.size.y)
-        u_min, u_max = max(center_x - size_x // 2, 0), min(center_x + size_x // 2, depth_image.shape[1] - 1)
-        v_min, v_max = max(center_y - size_y // 2, 0), min(center_y + size_y // 2, depth_image.shape[0] - 1)
-        roi = depth_image[v_min:v_max, u_min:u_max] / self.depth_image_units_divisor  # convert to meters
+
+        u_min = max(center_x - size_x // 2, 0)
+        u_max = min(center_x + size_x // 2, depth_image.shape[1] - 1)
+        v_min = max(center_y - size_y // 2, 0)
+        v_max = min(center_y + size_y // 2, depth_image.shape[0] - 1)
+
+        roi = depth_image[v_min:v_max, u_min:u_max] / \
+            self.depth_image_units_divisor  # convert to meters
 
         if not np.any(roi):
             return None
         # find the z coordinate on the 3D BB
-        bb_center_z_coord = depth_image[int(center_y)][int(center_x)] / self.depth_image_units_divisor
+        bb_center_z_coord = depth_image[int(center_y)][int(
+            center_x)] / self.depth_image_units_divisor
         z_diff = np.abs(roi - bb_center_z_coord)
         mask_z = z_diff <= self.maximum_detection_threshold
-        mask_z = z_diff <= self.maximum_detection_threshold
         if not np.any(mask_z):
             return None
+
         roi_threshold = roi[mask_z]
         z_min, z_max = np.min(roi_threshold), np.max(roi_threshold)
         z = (z_max + z_min) / 2
@@ -169,14 +182,16 @@ def convert_bb_to_3d(self,
 
         return msg
 
-    def convert_keypoints_to_3d(self,
-                                depth_image: np.ndarray,
-                                depth_info: CameraInfo,
-                                detection: Detection
-                                ) -> KeyPoint3DArray:
+    def convert_keypoints_to_3d(
+        self,
+        depth_image: np.ndarray,
+        depth_info: CameraInfo,
+        detection: Detection
+    ) -> KeyPoint3DArray:
 
         # build an array of 2d keypoints
-        keypoints_2d = np.array([[p.point.x, p.point.y] for p in detection.keypoints.data], dtype=np.int16)
+        keypoints_2d = np.array([[p.point.x, p.point.y]
+                                for p in detection.keypoints.data], dtype=np.int16)
         u = np.array(keypoints_2d[:, 1]).clip(0, depth_info.height - 1)
         v = np.array(keypoints_2d[:, 0]).clip(0, depth_info.width - 1)
 
@@ -186,7 +201,8 @@ def convert_keypoints_to_3d(self,
         px, py, fx, fy = k[2], k[5], k[0], k[4]
         x = z * (v - px) / fx
         y = z * (u - py) / fy
-        points_3d = np.dstack([x, y, z]).reshape(-1, 3) / self.depth_image_units_divisor  # convert to meters
+        points_3d = np.dstack([x, y, z]).reshape(-1, 3) / \
+            self.depth_image_units_divisor  # convert to meters
 
         # generate message
         msg_array = KeyPoint3DArray()
@@ -203,7 +219,7 @@ def convert_keypoints_to_3d(self,
         return msg_array
 
     def get_transform(self, frame_id: str) -> Tuple[np.ndarray]:
-        # transform position from point cloud frame to target_frame
+        # transform position from image frame to target_frame
         rotation = None
         translation = None
 
@@ -296,5 +312,7 @@ def qv_mult(q: np.ndarray, v: np.ndarray) -> np.ndarray:
 
 def main():
     rclpy.init()
-    rclpy.spin(Detect3DNode())
+    node = Detect3DNode()
+    rclpy.spin(node)
+    node.destroy_node()
     rclpy.shutdown()
diff --git a/yolov8_ros/yolov8_ros/yolov8_node.py b/yolov8_ros/yolov8_ros/yolov8_node.py
index e5e3a3a..d892651 100644
--- a/yolov8_ros/yolov8_ros/yolov8_node.py
+++ b/yolov8_ros/yolov8_ros/yolov8_node.py
@@ -77,10 +77,11 @@ def __init__(self) -> None:
         # services
         self._srv = self.create_service(SetBool, "enable", self.enable_cb)
 
-    def enable_cb(self,
-                  req: SetBool.Request,
-                  res: SetBool.Response
-                  ) -> SetBool.Response:
+    def enable_cb(
+        self,
+        req: SetBool.Request,
+        res: SetBool.Response
+    ) -> SetBool.Response:
         self.enable = req.data
         res.success = True
         return res
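
Note (editorial addition, not part of the patch): the wrapped `points_3d` expression in `convert_keypoints_to_3d` is a standard pinhole back-projection using the `CameraInfo.k` intrinsics (fx = k[0], fy = k[4], cx = k[2], cy = k[5], as assigned in the hunk above). A minimal, self-contained sketch of that math, with made-up intrinsics and depth values used purely for illustration:

    import numpy as np

    # hypothetical intrinsics and depth samples -- illustrative values only
    fx, fy, cx, cy = 525.0, 525.0, 319.5, 239.5
    depth_image_units_divisor = 1000.0   # e.g. depth encoded in millimeters

    cols = np.array([100, 320])          # pixel coordinates along the image width
    rows = np.array([80, 240])           # pixel coordinates along the image height
    depth = np.array([1500.0, 2000.0])   # raw depth readings at those pixels

    # back-project each pixel through the pinhole model
    z = depth
    x = z * (cols - cx) / fx
    y = z * (rows - cy) / fy

    # stack into an (N, 3) array and convert to meters, as the node does
    points_3d = np.dstack([x, y, z]).reshape(-1, 3) / depth_image_units_divisor
    print(points_3d)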