From 61e41b3e992c48a695f9724410b5b3a3506af74b Mon Sep 17 00:00:00 2001
From: Binary2355
Date: Fri, 14 Jun 2024 15:54:29 +0800
Subject: [PATCH] npu knn/tnn bugfix

---
 mmcv/ops/csrc/pytorch/npu/knn_npu.cpp      | 21 +++++++++++++++++++++
 mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp | 20 ++++++++++++++++++++
 mmcv/ops/knn.py                            | 15 +++++++++++++--
 mmcv/ops/three_nn.py                       | 14 ++++++++++++++
 4 files changed, 68 insertions(+), 2 deletions(-)
 create mode 100644 mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
 create mode 100644 mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp

diff --git a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
new file mode 100644
index 0000000000..c4a1bcbd25
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
@@ -0,0 +1,21 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
+                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
+  // transpose xyz (the known points) from [B, N, 3] to [B, 3, N]
+  at::Tensor source = xyz.transpose(2, 1).contiguous();
+  at::Tensor target = new_xyz.contiguous();
+
+  bool is_from_knn = true;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
+                      const Tensor new_xyz, Tensor idx, Tensor dist2);
+
+REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
diff --git a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
new file mode 100644
index 0000000000..6740a731bc
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
@@ -0,0 +1,20 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
+                          const Tensor known, Tensor dist2, Tensor idx) {
+  at::Tensor source = known.contiguous();
+  at::Tensor target = unknown.contiguous();
+
+  bool is_from_knn = false;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
+                           const Tensor known, Tensor dist2, Tensor idx);
+
+REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py
index 48ce92f925..09c8a68b04 100644
--- a/mmcv/ops/knn.py
+++ b/mmcv/ops/knn.py
@@ -55,12 +55,23 @@ def forward(ctx,
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
             'center_xyz and xyz should be put on the same device'
-        if torch.cuda.current_device() != center_xyz_device:
-            torch.cuda.set_device(center_xyz_device)
+        if xyz.device.type != 'npu':
+            if torch.cuda.current_device() != center_xyz_device:
+                torch.cuda.set_device(center_xyz_device)
 
         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]
 
+        if xyz.device.type == 'npu':
+            dist = center_xyz.new_zeros((B, npoint, N)).float()
+            ext_module.knn_forward(
+                xyz, center_xyz, torch.Tensor([]).npu(), dist, b=B, n=N, m=npoint, nsample=k)
+            dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
+            zeros_idx = torch.zeros(xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
+            idx = idx.where(dist2 < 1e10, zeros_idx)  # padded neighbors get index 0
+            idx = idx.transpose(2, 1).contiguous()  # [B, k, npoint]
+            return idx.type(torch.IntTensor)
+
         idx = center_xyz.new_zeros((B, npoint, k)).int()
         dist2 = center_xyz.new_zeros((B, npoint, k)).float()
 
diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py
index d41b9789cf..2c6a6a93b1 100644
--- a/mmcv/ops/three_nn.py
+++ b/mmcv/ops/three_nn.py
@@ -34,6 +34,20 @@ def forward(ctx: Any, target: torch.Tensor,
 
         B, N, _ = target.size()
         m = source.size(1)
+        if source.device.type == 'npu':
+            # the NPU kernel runs in fp32 only: cast fp16 inputs up, cast back below
+            source = source.transpose(2, 1).contiguous()
+            dtype_ = source.dtype
+            if dtype_ == torch.float16:
+                target = target.float()
+                source = source.float()
+            dist = target.new_empty(B, N, m)
+            ext_module.three_nn_forward(target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
+            dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
+            dist2 = torch.sqrt(dist2)
+            if dtype_ == torch.float16:
+                dist2 = dist2.half()
+            return dist2, idx.type(torch.IntTensor)
         dist2 = target.new_empty(B, N, 3)
         idx = target.new_empty(B, N, 3, dtype=torch.int32)
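
Usage sketch (not part of the patch): from the caller's side the new NPU
branch of mmcv.ops.knn is expected to behave like the CUDA path, returning
the indices of the k nearest reference points for every query point. The
shapes and k below are illustrative, and the .npu() calls assume torch_npu
is installed with an Ascend device visible.

    import torch
    from mmcv.ops import knn

    xyz = torch.rand(2, 1024, 3).npu()        # reference points, [B, N, 3]
    center_xyz = torch.rand(2, 256, 3).npu()  # query points, [B, npoint, 3]

    # knn is KNN.apply, so arguments are positional:
    # (k, xyz, center_xyz, transposed). transposed=False means the inputs
    # are laid out [B, N, 3]. The result has shape [B, k, npoint], as on CUDA.
    idx = knn(8, xyz, center_xyz, False)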
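Likewise for three_nn, a minimal sketch under the same torch_npu assumption;
the point counts are made up. For every target point the op returns the
distances to its 3 nearest source points and their indices.

    import torch
    from mmcv.ops import three_nn

    target = torch.rand(2, 256, 3).npu()  # unknown points, [B, N, 3]
    source = torch.rand(2, 64, 3).npu()   # known points, [B, m, 3]

    # three_nn is ThreeNN.apply: positional arguments (target, source).
    # dist and idx both have shape [B, N, 3].
    dist, idx = three_nn(target, source)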