From 737b5b4f08ddeee2a76d40911b1d5dbfcc82f0f2 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Thu, 30 May 2024 19:22:04 +0800 Subject: [PATCH] fix gather_point bug. --- .../furthest_point_sampling_with_dist_npu.cpp | 8 ++++---- mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp | 16 +++++++++++++--- mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp | 3 ++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp index 364d3bfa9a..24317a06bb 100644 --- a/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/furthest_point_sampling_with_dist_npu.cpp @@ -6,11 +6,11 @@ void furthest_point_sampling_with_dist_npu(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, int b, int n, int m) { - auto output_size = {b, m}; - at::Tensor result = - at::empty(output_size, points_tensor.options().dtype(at::kInt)); + TORCH_CHECK( + (points_tensor.sizes()[1] >= m), + "the number of sampled points should not be larger than the total number of points."); EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor, - m, result); + m, idx_tensor); } void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor, diff --git a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp index cf3a577ce1..991e6038db 100644 --- a/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp @@ -24,6 +24,12 @@ void gather_points_forward_npu(int b, int c, int n, int npoints, void gather_points_backward_npu(int b, int c, int n, int npoints, const Tensor grad_out, const Tensor idx, Tensor grad_points) { + at::Tensor grad_out_cast = grad_out; + at::Tensor grad_points_cast = grad_points; + if (grad_out.scalar_type() == at::ScalarType::Half) { + grad_out_cast = grad_out.to(at::kFloat); + grad_points_cast = 
grad_points.to(at::kFloat); + } at::Tensor indices = idx; if (idx.scalar_type() != at::ScalarType::Int) { indices = idx.to(at::kInt); @@ -37,11 +43,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, for (uint64_t i = 0; i < shape.size(); i++) { pad_size.emplace_back(shape[i]); } - at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous(); + at::Tensor trans_grad_points = grad_points_cast.transpose(1, 2).contiguous(); at::Tensor grad_points_view = trans_grad_points.view( {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1], trans_grad_points.sizes()[2]}); - at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous(); + at::Tensor trans_grad_out = grad_out_cast.transpose(1, 2).contiguous(); trans_grad_out = trans_grad_out.view( {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1], trans_grad_out.sizes()[2]}); @@ -63,7 +69,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints, at::Tensor grad_points_result = grad_points_view.view(trans_grad_points.sizes()); grad_points_result = grad_points_result.transpose(1, 2); - grad_points.copy_(grad_points_result); + at::Tensor grad_points_result_cast = grad_points_result; + if (grad_out.scalar_type() == at::ScalarType::Half) { + grad_points_result_cast = grad_points_result.to(at::kHalf); + } + grad_points.copy_(grad_points_result_cast); } void gather_points_forward_impl(int b, int c, int n, int npoints, diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index 2b3af2575c..b7015439b9 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -70,7 +70,8 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, .Attr("spatial_scale_w", spatial_scale) .Attr("pool_channel", pooled_channel) .Run(); - at::Tensor res = NpuUtils::format_contiguous(result); + at::Tensor result = y.transpose(2, 3).transpose(1, 2); + at::Tensor res = result.contiguous(); 
grad_input.copy_(res); }