diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 4915e6b573..e58a6e2a12 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -47,8 +47,11 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + else: + boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_part_forward(boxes.contiguous(), points.contiguous(), @@ -127,8 +130,9 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) ext_module.points_in_boxes_all_forward(boxes.contiguous(), points.contiguous(),