From 6c18d259c614b2a8d9b9cd0f615529dc50861d82 Mon Sep 17 00:00:00 2001 From: dai-z Date: Wed, 22 Dec 2021 20:18:26 +0800 Subject: [PATCH] [bug fixed] Fix bbox coord in service --- darknet_images.py | 6 +++++- detect_service.py | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/darknet_images.py b/darknet_images.py index 17ac91731a7..2d55ac09fda 100644 --- a/darknet_images.py +++ b/darknet_images.py @@ -105,6 +105,7 @@ def image_detection(image_path, network, class_names, class_colors, thresh): darknet_image = darknet.make_image(width, height, 3) image = cv2.imread(image_path) + ori_height, ori_width = image.shape[:2] image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image_resized = cv2.resize(image_rgb, (width, height), interpolation=cv2.INTER_LINEAR) @@ -113,6 +114,8 @@ def image_detection(image_path, network, class_names, class_colors, thresh): detections = darknet.detect_image(network, class_names, darknet_image, thresh=thresh) darknet.free_image(darknet_image) image = darknet.draw_boxes(detections, image_resized, class_colors) + image = cv2.resize(image, (ori_width, ori_height), interpolation=cv2.INTER_LINEAR) + print(image.shape) return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), detections @@ -181,6 +184,7 @@ def batch_detection_example(): args.weights, batch_size=batch_size ) + print(class_colors) image_names = ['data/horses.jpg', 'data/horses.jpg', 'data/eagle.jpg'] images = [cv2.imread(image) for image in image_names] images, detections, = batch_detection(network, images, class_names, @@ -194,7 +198,7 @@ def main(): args = parser() check_arguments_errors(args) - random.seed(3) # deterministic bbox colors + random.seed(0) # deterministic bbox colors network, class_names, class_colors = darknet.load_network( args.config_file, args.data_file, diff --git a/detect_service.py b/detect_service.py index 791065be98d..4fd3eb2ab38 100644 --- a/detect_service.py +++ b/detect_service.py @@ -88,6 +88,10 @@ def handle_image(req: DetectRequest, network, class_names, thresh): # print("Returning [%s + %s = %s]"%(req.a, req.b, (req.a + req.b))) image = CvBridge().imgmsg_to_cv2(req.image) + vis = image.copy() + width = darknet.network_width(network) + height = darknet.network_height(network) + ori_height, ori_width = image.shape[:2] # prev_time = time.time() # [(label_name, prob(%), x, y, w, h)] detections = image_detection(image, network, class_names, thresh) @@ -95,10 +99,17 @@ def handle_image(req: DetectRequest, network, class_names, thresh): # fps = int(1/(time.time() - prev_time)) # print("FPS: {}".format(fps)) bboxes = [] + scale_w = ori_width / width + scale_h = ori_height / height for det in detections: + left, top, right, bottom = darknet.bbox2points(det[2]) + left *= scale_w + right *= scale_w + top *= scale_h + bottom *= scale_h + vis = cv2.rectangle(vis, (int(left), int(top)), (int(right), int(bottom)), (0,0,255), 5) # label, x_min ,x_max, y_min, y_max, prob - box = BBox(0, det[2][0], det[2][1], det[2][0] + det[2][2], - det[2][1] + det[2][3], float(det[1])) + box = BBox(0, left, top, right, bottom, float(det[1])) bboxes.append(box) result = DetectResult(0, len(bboxes), bboxes) @@ -126,4 +137,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()