diff --git a/models/yolo_v5_object_detector.py b/models/yolo_v5_object_detector.py index 274ecc3..8ce73dd 100644 --- a/models/yolo_v5_object_detector.py +++ b/models/yolo_v5_object_detector.py @@ -110,7 +110,9 @@ def non_max_suppression(prediction, logits, conf_thres=0.6, iou_thres=0.45, clas continue # Compute conf - x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + obj_conf = x[:, 4:5].clone() + cls_conf = x[:, 5:].clone() + x[:, 5:] = obj_conf * cls_conf # conf = obj_conf * cls_conf # log_ *= x[:, 4:5] # Box (center x, center y, width, height) to (x1, y1, x2, y2) box = xywh2xyxy(x[:, :4])