Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Oct 27, 2023
1 parent 1b696dc commit 0719a3e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 39 deletions.
88 changes: 53 additions & 35 deletions yolov8_ros/yolov8_ros/detect_3d_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,27 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import List, Tuple

import message_filters
import numpy as np
from typing import List, Tuple

import rclpy
from cv_bridge import CvBridge
from geometry_msgs.msg import TransformStamped
from rclpy.node import Node
from rclpy.qos import qos_profile_sensor_data
from sensor_msgs.msg import CameraInfo, Image
from tf2_ros import TransformException
from geometry_msgs.msg import TransformStamped

import message_filters
from cv_bridge import CvBridge
from tf2_ros.buffer import Buffer
from tf2_ros import TransformException
from tf2_ros.transform_listener import TransformListener

from yolov8_msgs.msg import (BoundingBox3D, Detection, DetectionArray,
KeyPoint3D, KeyPoint3DArray)
from yolov8_msgs.msg import Detection
from yolov8_msgs.msg import DetectionArray
from yolov8_msgs.msg import KeyPoint3D
from yolov8_msgs.msg import KeyPoint3DArray
from yolov8_msgs.msg import BoundingBox3D


class Detect3DNode(Node):
Expand Down Expand Up @@ -69,11 +74,12 @@ def __init__(self) -> None:
(self.depth_sub, self.depth_info_sub, self.detections_sub), 10, 0.5)
self._synchronizer.registerCallback(self.on_detections)

def on_detections(self,
depth_msg: Image,
depth_info_msg: CameraInfo,
detections_msg: DetectionArray,
) -> None:
def on_detections(
self,
depth_msg: Image,
depth_info_msg: CameraInfo,
detections_msg: DetectionArray,
) -> None:

new_detections_msg = DetectionArray()
new_detections_msg.header = detections_msg.header
Expand All @@ -82,10 +88,10 @@ def on_detections(self,
self._pub.publish(new_detections_msg)

def process_detections(
self,
depth_msg: Image,
depth_info_msg: CameraInfo,
detections_msg: DetectionArray
self,
depth_msg: Image,
depth_info_msg: CameraInfo,
detections_msg: DetectionArray
) -> List[Detection]:

# check if there are detections
Expand Down Expand Up @@ -122,30 +128,37 @@ def process_detections(

return new_detections

def convert_bb_to_3d(self,
depth_image: np.ndarray,
depth_info: CameraInfo,
detection: Detection
) -> BoundingBox3D:
def convert_bb_to_3d(
self,
depth_image: np.ndarray,
depth_info: CameraInfo,
detection: Detection
) -> BoundingBox3D:

# crop depth image by the 2d BB
center_x = int(detection.bbox.center.position.x)
center_y = int(detection.bbox.center.position.y)
size_x = int(detection.bbox.size.x)
size_y = int(detection.bbox.size.y)
u_min, u_max = max(center_x - size_x // 2, 0), min(center_x + size_x // 2, depth_image.shape[1] - 1)
v_min, v_max = max(center_y - size_y // 2, 0), min(center_y + size_y // 2, depth_image.shape[0] - 1)
roi = depth_image[v_min:v_max, u_min:u_max] / self.depth_image_units_divisor # convert to meters

u_min = max(center_x - size_x // 2, 0)
u_max = min(center_x + size_x // 2, depth_image.shape[1] - 1)
v_min = max(center_y - size_y // 2, 0)
v_max = min(center_y + size_y // 2, depth_image.shape[0] - 1)

roi = depth_image[v_min:v_max, u_min:u_max] / \
self.depth_image_units_divisor # convert to meters
if not np.any(roi):
return None

# find the z coordinate on the 3D BB
bb_center_z_coord = depth_image[int(center_y)][int(center_x)] / self.depth_image_units_divisor
bb_center_z_coord = depth_image[int(center_y)][int(
center_x)] / self.depth_image_units_divisor
z_diff = np.abs(roi - bb_center_z_coord)
mask_z = z_diff <= self.maximum_detection_threshold
mask_z = z_diff <= self.maximum_detection_threshold
if not np.any(mask_z):
return None

roi_threshold = roi[mask_z]
z_min, z_max = np.min(roi_threshold), np.max(roi_threshold)
z = (z_max + z_min) / 2
Expand All @@ -169,14 +182,16 @@ def convert_bb_to_3d(self,

return msg

def convert_keypoints_to_3d(self,
depth_image: np.ndarray,
depth_info: CameraInfo,
detection: Detection
) -> KeyPoint3DArray:
def convert_keypoints_to_3d(
self,
depth_image: np.ndarray,
depth_info: CameraInfo,
detection: Detection
) -> KeyPoint3DArray:

# build an array of 2d keypoints
keypoints_2d = np.array([[p.point.x, p.point.y] for p in detection.keypoints.data], dtype=np.int16)
keypoints_2d = np.array([[p.point.x, p.point.y]
for p in detection.keypoints.data], dtype=np.int16)
u = np.array(keypoints_2d[:, 1]).clip(0, depth_info.height - 1)
v = np.array(keypoints_2d[:, 0]).clip(0, depth_info.width - 1)

Expand All @@ -186,7 +201,8 @@ def convert_keypoints_to_3d(self,
px, py, fx, fy = k[2], k[5], k[0], k[4]
x = z * (v - px) / fx
y = z * (u - py) / fy
points_3d = np.dstack([x, y, z]).reshape(-1, 3) / self.depth_image_units_divisor # convert to meters
points_3d = np.dstack([x, y, z]).reshape(-1, 3) / \
self.depth_image_units_divisor # convert to meters

# generate message
msg_array = KeyPoint3DArray()
Expand All @@ -203,7 +219,7 @@ def convert_keypoints_to_3d(self,
return msg_array

def get_transform(self, frame_id: str) -> Tuple[np.ndarray]:
# transform position from point cloud frame to target_frame
# transform position from image frame to target_frame
rotation = None
translation = None

Expand Down Expand Up @@ -296,5 +312,7 @@ def qv_mult(q: np.ndarray, v: np.ndarray) -> np.ndarray:

def main():
rclpy.init()
rclpy.spin(Detect3DNode())
node = Detect3DNode()
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown()
9 changes: 5 additions & 4 deletions yolov8_ros/yolov8_ros/yolov8_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ def __init__(self) -> None:
# services
self._srv = self.create_service(SetBool, "enable", self.enable_cb)

def enable_cb(self,
req: SetBool.Request,
res: SetBool.Response
) -> SetBool.Response:
def enable_cb(
self,
req: SetBool.Request,
res: SetBool.Response
) -> SetBool.Response:
self.enable = req.data
res.success = True
return res
Expand Down

0 comments on commit 0719a3e

Please sign in to comment.