Skip to content

Commit

Permalink
add spconv
Browse files Browse the repository at this point in the history
  • Loading branch information
LYX-SOUL committed Oct 14, 2024
1 parent c7c11de commit 88ecac2
Show file tree
Hide file tree
Showing 6 changed files with 2,345 additions and 0 deletions.
13 changes: 13 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6206,3 +6206,16 @@ def spmm(input, mat2) -> Tensor:
ret = func(input.context(), out, input, mat2)
check_returncode(ret)
return out

def spconv(in_feat, kernel, neighbor_map, sum_nnz, neighbor_address, q_neighbor_address,
           output_size, qsum_nnz, transpose, allow_tf32, allow_fp16) -> Tensor:
    """Invoke the ``diopiSpConv`` kernel and return its output feature tensor.

    The output tensor is allocated here with shape
    ``[output_size, kernel.size().data[2]]`` and the same dtype as ``in_feat``;
    every other argument is forwarded unchanged to the DIOPI function.

    Args:
        in_feat: input feature tensor; supplies the context and the output dtype.
        kernel: convolution weights; its third dimension is used as the number
            of output channels.
        neighbor_map, neighbor_address, q_neighbor_address: index tensors
            forwarded as-is (presumably int32 -- TODO confirm against callers).
        sum_nnz, qsum_nnz, output_size: scalar sizes forwarded as-is.
        transpose, allow_tf32, allow_fp16: flags forwarded as-is.

    Returns:
        The freshly allocated output tensor once ``check_returncode`` has
        accepted the status returned by the kernel.
    """
    n_out_channels = kernel.size().data[2]
    spconv_impl = check_function("diopiSpConv")

    # The commented-out C++ reference this was ported from allocated the buffer
    # with torch::zeros; whether Tensor(...) zero-fills is not visible here.
    result = Tensor([output_size, n_out_channels], dtype=in_feat.get_dtype())

    status = spconv_impl(
        in_feat.context(),
        result,
        in_feat,
        kernel,
        neighbor_map,
        sum_nnz,
        neighbor_address,
        q_neighbor_address,
        output_size,
        qsum_nnz,
        transpose,
        allow_tf32,
        allow_fp16,
    )
    check_returncode(status)
    return result
1 change: 1 addition & 0 deletions diopi_test/python/conformance/mytest.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'diopilib'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mdiopilib\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mconformance\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdiopi_functions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m spmm,spconv\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mconformance\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdiopi_runtime\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Tensor\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'diopilib'"]}],"source":["import torch\n","import diopilib\n","from conformance.diopi_functions import spmm,spconv\n","from conformance.diopi_runtime import Tensor\n","import numpy as np\n","# from torch.nn.functional import dropout\n","# sparse matrix:\n","# 0.5, 0, 1\n","# 0, 0, 2\n","# 1, 3, 0.6\n","\n","# M, K, N = 4096, 4096, 128\n","\n","# a = torch.randn((M,K),dtype=torch.float32)\n","# a = dropout(a, p=0.9)\n","# sparse_a = a.to_sparse_csr()\n","# print(a)\n","# print(sparse_a)\n","# b = np.random.randn(K,N).astype(np.float32)\n","# print(b)\n","\n","# input = Tensor.from_numpy(b)\n","# row_ptr = Tensor.from_numpy(sparse_a.crow_indices().numpy().astype(np.int32))\n","# col_ind = Tensor.from_numpy(sparse_a.col_indices().numpy().astype(np.int32))\n","# values = Tensor.from_numpy(sparse_a.values().numpy().astype(np.float32))\n","# 
print(list(input.size().data))\n","# c = spmm(row_ptr, col_ind, values, input)\n","# c_ref = a @ torch.from_numpy(b)\n","# # c = rbrmsr_spmm(sparse_a.crow_indices(), sparse_a.col_indices(), sparse_a.values(), b)\n","\n","\n","spconv_data = torch.load('spconv_data.pth')\n","in_feat = spconv_data['in_feats']\n","kernel = spconv_data['kernel']\n","neighbor_map = spconv_data['neighbor_map']\n","sum_nnz = spconv_data['sum_nnz']\n","neighbor_address = spconv_data['neighbor_address']\n","q_neighbor_address = spconv_data['q_neighbor_address']\n","output_size = spconv_data['output_size']\n","qsum_nnz = spconv_data['qsum_nnz'].item()\n","\n","in_feat = Tensor.from_numpy(in_feat.numpy().astype(np.float16))\n","kernel = Tensor.from_numpy(kernel.numpy().astype(np.float16))\n","neighbor_map = Tensor.from_numpy(neighbor_map.numpy().astype(np.int32))\n","neighbor_address = Tensor.from_numpy(neighbor_address.numpy().astype(np.int32))\n","q_neighbor_address = Tensor.from_numpy(q_neighbor_address.numpy().astype(np.int32))\n","\n","out = spconv(in_feat, kernel,neighbor_map,sum_nnz,neighbor_address,q_neighbor_address,output_size,\n"," qsum_nnz,False,True,True) \n","\n","print(out)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["ours: tensor([[ 1.9180e+02, -3.3436e+02, -2.5628e+02, ..., -1.2209e+02,\n"," 1.2449e+02, 8.9544e+00],\n"," [ 5.4242e+00, -9.1456e+01, -8.8789e+01, ..., 2.9239e+02,\n"," 5.9588e+02, -6.7169e+01],\n"," [ 9.6391e+01, -2.8461e+02, -8.9830e+01, ..., 2.4198e+02,\n"," 1.9330e+02, -2.7285e+02],\n"," ...,\n"," [-9.6388e+01, -1.8851e+02, -3.9614e+02, ..., 3.3607e+02,\n"," 9.8614e+01, -2.6860e+02],\n"," [-8.1585e+01, -1.1838e+02, -1.9467e+02, ..., 3.6153e+02,\n"," -1.0657e+02, -3.8279e+02],\n"," [ 1.8823e+02, -3.1072e+02, -2.3838e+02, ..., -3.6757e-01,\n"," -5.7795e+01, -1.2785e+02]])\n","ref: tensor([[ 1.9180e+02, -3.3436e+02, -2.5628e+02, ..., -1.2209e+02,\n"," 1.2449e+02, 8.9544e+00],\n"," [ 
5.4243e+00, -9.1456e+01, -8.8789e+01, ..., 2.9239e+02,\n"," 5.9588e+02, -6.7170e+01],\n"," [ 9.6391e+01, -2.8461e+02, -8.9830e+01, ..., 2.4198e+02,\n"," 1.9330e+02, -2.7285e+02],\n"," ...,\n"," [-9.6388e+01, -1.8851e+02, -3.9614e+02, ..., 3.3607e+02,\n"," 9.8614e+01, -2.6860e+02],\n"," [-8.1585e+01, -1.1838e+02, -1.9467e+02, ..., 3.6153e+02,\n"," -1.0657e+02, -3.8279e+02],\n"," [ 1.8823e+02, -3.1072e+02, -2.3838e+02, ..., -3.6760e-01,\n"," -5.7795e+01, -1.2785e+02]])\n"]}],"source":["c_ours = torch.from_numpy(c.numpy())\n","print(\"ours:\", c_ours)\n","print(\"ref: \", c_ref)\n","assert torch.allclose(c_ours, c_ref, rtol=1e-03, atol=1e-03)"]}],"metadata":{"kernelspec":{"display_name":"eda","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.18"}},"nbformat":4,"nbformat_minor":2}
24 changes: 24 additions & 0 deletions impl/torch/functions/functions_sparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,29 @@ diopiError_t diopiSpMM(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC
return diopiErrorOccurred;
}


/**
 * @brief Sparse convolution forward pass exposed through the DIOPI C ABI.
 *
 * Builds ATen tensors from the DIOPI handles, zero-fills the output feature
 * tensor, and forwards everything to
 * sparse::ops::conv_forward_fetch_on_demand_cuda.
 *
 * @param ctx                 DIOPI context; used to select the current stream.
 * @param out_feat            Output features; zeroed here before the kernel runs.
 * @param in_feat             Input features.
 * @param kernel              Convolution weights.
 * @param neighbor_map        Forwarded unchanged to the kernel.
 * @param sum_nnz             Forwarded unchanged to the kernel.
 * @param neighbor_address    Forwarded unchanged to the kernel.
 * @param q_neighbor_address  Forwarded unchanged to the kernel.
 * @param output_size         Forwarded unchanged to the kernel.
 * @param qsum_nnz            Forwarded unchanged to the kernel.
 * @param transpose           Forwarded unchanged to the kernel.
 * @param allow_tf32          Forwarded unchanged to the kernel.
 * @param allow_fp16          Forwarded unchanged to the kernel.
 * @return Always diopiSuccess; nothing thrown by the callees is caught here.
 *
 * NOTE(review): `kernel` is declared as `diopicTensorHandle_t`; confirm that
 * this type name is actually declared somewhere and is not a misspelling.
 */
extern "C" diopiError_t diopiSpConv(diopiContextHandle_t ctx, diopiTensorHandle_t out_feat, diopiTensorHandle_t in_feat,
                                    diopicTensorHandle_t kernel, diopiTensorHandle_t neighbor_map, const int sum_nnz,
                                    diopiTensorHandle_t neighbor_address, diopiTensorHandle_t q_neighbor_address,
                                    const int output_size, const int qsum_nnz, const bool transpose,
                                    const bool allow_tf32, const bool allow_fp16) {
    impl::aten::setCurStream(ctx);

    // Build ATen tensors from each DIOPI handle.
    auto inputFeatures = impl::aten::buildATen(in_feat);
    auto outputFeatures = impl::aten::buildATen(out_feat);
    auto weights = impl::aten::buildATen(kernel);
    auto neighborMap = impl::aten::buildATen(neighbor_map);
    auto neighborAddr = impl::aten::buildATen(neighbor_address);
    auto queryNeighborAddr = impl::aten::buildATen(q_neighbor_address);

    // Clear the output buffer before handing it to the kernel.
    outputFeatures.zero_();

    sparse::ops::conv_forward_fetch_on_demand_cuda(inputFeatures,
                                                   outputFeatures,
                                                   weights,
                                                   neighborMap,
                                                   sum_nnz,
                                                   neighborAddr,
                                                   queryNeighborAddr,
                                                   output_size,
                                                   qsum_nnz,
                                                   transpose,
                                                   allow_tf32,
                                                   allow_fp16);

    return diopiSuccess;
}

} // namespace cuda
} // namespace impl
Loading

0 comments on commit 88ecac2

Please sign in to comment.