Skip to content

Commit

Permalink
Fix NPU op bugs: gather_points fp16 backward cast, furthest_point_sampling_with_dist output tensor, and roi_pool backward layout.
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Jun 13, 2024
1 parent abf8ca7 commit 737b5b4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ void furthest_point_sampling_with_dist_npu(Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor, int b, int n,
int m) {
auto output_size = {b, m};
at::Tensor result =
at::empty(output_size, points_tensor.options().dtype(at::kInt));
TORCH_CHECK(
(points_tensor.sizes()[1] >= m),
"the num of sampled points should smaller than total num of points.");
EXEC_NPU_CMD(aclnnFurthestPointSamplingWithDist, points_tensor, temp_tensor,
m, result);
m, idx_tensor);
}

void furthest_point_sampling_with_dist_forward_impl(Tensor points_tensor,
Expand Down
16 changes: 13 additions & 3 deletions mmcv/ops/csrc/pytorch/npu/gather_points_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
void gather_points_backward_npu(int b, int c, int n, int npoints,
const Tensor grad_out, const Tensor idx,
Tensor grad_points) {
at::Tensor grad_out_cast = grad_out;
at::Tensor grad_points_cast = grad_points;
if (grad_out.scalar_type() == at::ScalarType::Half) {
grad_out_cast = grad_out.to(at::kFloat);
grad_points_cast = grad_points.to(at::kFloat);
}
at::Tensor indices = idx;
if (idx.scalar_type() != at::ScalarType::Int) {
indices = idx.to(at::kInt);
Expand All @@ -37,11 +43,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints,
for (uint64_t i = 0; i < shape.size(); i++) {
pad_size.emplace_back(shape[i]);
}
at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
at::Tensor trans_grad_points = grad_points_cast.transpose(1, 2).contiguous();
at::Tensor grad_points_view = trans_grad_points.view(
{trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
trans_grad_points.sizes()[2]});
at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
at::Tensor trans_grad_out = grad_out_cast.transpose(1, 2).contiguous();
trans_grad_out = trans_grad_out.view(
{trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
trans_grad_out.sizes()[2]});
Expand All @@ -63,7 +69,11 @@ void gather_points_backward_npu(int b, int c, int n, int npoints,
at::Tensor grad_points_result =
grad_points_view.view(trans_grad_points.sizes());
grad_points_result = grad_points_result.transpose(1, 2);
grad_points.copy_(grad_points_result);
at::Tensor grad_points_result_cast = grad_points_result;
if (grad_out.scalar_type() == at::ScalarType::Half) {
grad_points_result_cast = grad_points_result.to(at::kHalf);
}
grad_points.copy_(grad_points_result_cast);
}

void gather_points_forward_impl(int b, int c, int n, int npoints,
Expand Down
3 changes: 2 additions & 1 deletion mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
.Attr("spatial_scale_w", spatial_scale)
.Attr("pool_channel", pooled_channel)
.Run();
at::Tensor res = NpuUtils::format_contiguous(result);
at::Tensor result = y.transpose(2, 3).transpose(1, 2);
at::Tensor res = result.contiguous();
grad_input.copy_(res);
}

Expand Down

0 comments on commit 737b5b4

Please sign in to comment.