Skip to content

Commit

Permalink
[bug fixed] Fix bbox coord in service
Browse files Browse the repository at this point in the history
  • Loading branch information
Dai-z committed Dec 22, 2021
1 parent 6fb0e18 commit 6c18d25
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
6 changes: 5 additions & 1 deletion darknet_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def image_detection(image_path, network, class_names, class_colors, thresh):
darknet_image = darknet.make_image(width, height, 3)

image = cv2.imread(image_path)
ori_height, ori_width = image.shape[:2]
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_resized = cv2.resize(image_rgb, (width, height),
interpolation=cv2.INTER_LINEAR)
Expand All @@ -113,6 +114,8 @@ def image_detection(image_path, network, class_names, class_colors, thresh):
detections = darknet.detect_image(network, class_names, darknet_image, thresh=thresh)
darknet.free_image(darknet_image)
image = darknet.draw_boxes(detections, image_resized, class_colors)
image = cv2.resize(image, (ori_width, ori_height), interpolation=cv2.INTER_LINEAR)
print(image.shape)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), detections


Expand Down Expand Up @@ -181,6 +184,7 @@ def batch_detection_example():
args.weights,
batch_size=batch_size
)
print(class_colors)
image_names = ['data/horses.jpg', 'data/horses.jpg', 'data/eagle.jpg']
images = [cv2.imread(image) for image in image_names]
images, detections, = batch_detection(network, images, class_names,
Expand All @@ -194,7 +198,7 @@ def main():
args = parser()
check_arguments_errors(args)

random.seed(3) # deterministic bbox colors
random.seed(0) # deterministic bbox colors
network, class_names, class_colors = darknet.load_network(
args.config_file,
args.data_file,
Expand Down
17 changes: 14 additions & 3 deletions detect_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,28 @@ def handle_image(req: DetectRequest, network, class_names, thresh):
# print("Returning [%s + %s = %s]"%(req.a, req.b, (req.a + req.b)))

image = CvBridge().imgmsg_to_cv2(req.image)
vis = image.copy()
width = darknet.network_width(network)
height = darknet.network_height(network)
ori_height, ori_width = image.shape[:2]
# prev_time = time.time()
# [(label_name, prob(%), x, y, w, h)]
detections = image_detection(image, network, class_names, thresh)
# darknet.print_detections(detections, args.ext_output)
# fps = int(1/(time.time() - prev_time))
# print("FPS: {}".format(fps))
bboxes = []
scale_w = ori_width / width
scale_h = ori_height / height
for det in detections:
left, top, right, bottom = darknet.bbox2points(det[2])
left *= scale_w
right *= scale_w
top *= scale_h
bottom *= scale_h
vis = cv2.rectangle(vis, (int(left), int(top)), (int(right), int(bottom)), (0,0,255), 5)
# label, x_min ,x_max, y_min, y_max, prob
box = BBox(0, det[2][0], det[2][1], det[2][0] + det[2][2],
det[2][1] + det[2][3], float(det[1]))
box = BBox(0, left, top, right, bottom, float(det[1]))
bboxes.append(box)
result = DetectResult(0, len(bboxes), bboxes)

Expand Down Expand Up @@ -126,4 +137,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit 6c18d25

Please sign in to comment.