From 916ba039b779819f1ce1a078e1de3ede28a56a81 Mon Sep 17 00:00:00 2001 From: Justin Yip Date: Fri, 6 Dec 2024 15:26:31 -0800 Subject: [PATCH] [DNL] executorch export faster-rcnn --- torchvision/ops/boxes.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 309990ea03a..f9780c39a2d 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -69,10 +69,11 @@ def batched_nms( _log_api_usage_once(batched_nms) # Benchmarks that drove the following thresholds are at # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 - if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing(): - return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) - else: - return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) + return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) + #if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing(): + # return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) + #else: + # return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) @torch.jit._script_if_tracing @@ -104,7 +105,8 @@ def _batched_nms_vanilla( ) -> Tensor: # Based on Detectron2 implementation, just manually call nms() on each class independently keep_mask = torch.zeros_like(scores, dtype=torch.bool) - for class_id in torch.unique(idxs): + #for class_id in torch.unique(idxs): + for class_id in idxs: curr_indices = torch.where(idxs == class_id)[0] curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold) keep_mask[curr_indices[curr_keep_indices]] = True