From 88ecac21b17e2b7fc4a7456ef5a8f309d5c914ef Mon Sep 17 00:00:00 2001 From: lianyaoxiu <2294369811@qq.com> Date: Mon, 14 Oct 2024 07:42:18 +0000 Subject: [PATCH] add spconv --- .../python/conformance/diopi_functions.py | 13 + diopi_test/python/conformance/mytest.ipynb | 1 + impl/torch/functions/functions_sparse.cpp | 24 + .../functions/functions_sparse/spconv.cu | 2266 +++++++++++++++++ impl/torch/sparse_kernel.h | 13 + proto/include/diopi/functions_sparse.h | 28 + 6 files changed, 2345 insertions(+) create mode 100644 diopi_test/python/conformance/mytest.ipynb create mode 100644 impl/torch/functions/functions_sparse/spconv.cu diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 66bb2ac98..ba72f00e4 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -6206,3 +6206,16 @@ def spmm(input, mat2) -> Tensor: ret = func(input.context(), out, input, mat2) check_returncode(ret) return out + +def spconv(in_feat, kernel,neighbor_map,sum_nnz,neighbor_address,q_neighbor_address,output_size, + qsum_nnz,transpose,allow_tf32,allow_fp16) -> Tensor: + + out_channel = kernel.size().data[2] + func = check_function("diopiSpConv") + # at::Tensor out_feat = torch::zeros({output_size, out_channel}, + # at::device(in_feat.device()).dtype(in_feat.scalar_type())); + out_feat = Tensor(list([output_size, out_channel]), dtype=in_feat.get_dtype()) + ret = func(in_feat.context(),out_feat,in_feat,kernel,neighbor_map,sum_nnz,neighbor_address, + q_neighbor_address,output_size,qsum_nnz,transpose,allow_tf32,allow_fp16 ) + check_returncode(ret) + return out_feat \ No newline at end of file diff --git a/diopi_test/python/conformance/mytest.ipynb b/diopi_test/python/conformance/mytest.ipynb new file mode 100644 index 000000000..05e7b8895 --- /dev/null +++ b/diopi_test/python/conformance/mytest.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'diopilib'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mdiopilib\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mconformance\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdiopi_functions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m spmm,spconv\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mconformance\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdiopi_runtime\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Tensor\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'diopilib'"]}],"source":["import torch\n","import diopilib\n","from conformance.diopi_functions import spmm,spconv\n","from conformance.diopi_runtime import Tensor\n","import numpy as np\n","# from torch.nn.functional import dropout\n","# sparse matrix:\n","# 0.5, 0, 1\n","# 0, 0, 2\n","# 1, 3, 0.6\n","\n","# M, K, N = 4096, 4096, 128\n","\n","# a = torch.randn((M,K),dtype=torch.float32)\n","# a = dropout(a, p=0.9)\n","# sparse_a = a.to_sparse_csr()\n","# print(a)\n","# 
print(sparse_a)\n","# b = np.random.randn(K,N).astype(np.float32)\n","# print(b)\n","\n","# input = Tensor.from_numpy(b)\n","# row_ptr = Tensor.from_numpy(sparse_a.crow_indices().numpy().astype(np.int32))\n","# col_ind = Tensor.from_numpy(sparse_a.col_indices().numpy().astype(np.int32))\n","# values = Tensor.from_numpy(sparse_a.values().numpy().astype(np.float32))\n","# print(list(input.size().data))\n","# c = spmm(row_ptr, col_ind, values, input)\n","# c_ref = a @ torch.from_numpy(b)\n","# # c = rbrmsr_spmm(sparse_a.crow_indices(), sparse_a.col_indices(), sparse_a.values(), b)\n","\n","\n","spconv_data = torch.load('spconv_data.pth')\n","in_feat = spconv_data['in_feats']\n","kernel = spconv_data['kernel']\n","neighbor_map = spconv_data['neighbor_map']\n","sum_nnz = spconv_data['sum_nnz']\n","neighbor_address = spconv_data['neighbor_address']\n","q_neighbor_address = spconv_data['q_neighbor_address']\n","output_size = spconv_data['output_size']\n","qsum_nnz = spconv_data['qsum_nnz'].item()\n","\n","in_feat = Tensor.from_numpy(in_feat.numpy().astype(np.float16))\n","kernel = Tensor.from_numpy(kernel.numpy().astype(np.float16))\n","neighbor_map = Tensor.from_numpy(neighbor_map.numpy().astype(np.int32))\n","neighbor_address = Tensor.from_numpy(neighbor_address.numpy().astype(np.int32))\n","q_neighbor_address = Tensor.from_numpy(q_neighbor_address.numpy().astype(np.int32))\n","\n","out = spconv(in_feat, kernel,neighbor_map,sum_nnz,neighbor_address,q_neighbor_address,output_size,\n"," qsum_nnz,False,True,True) \n","\n","print(out)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["ours: tensor([[ 1.9180e+02, -3.3436e+02, -2.5628e+02, ..., -1.2209e+02,\n"," 1.2449e+02, 8.9544e+00],\n"," [ 5.4242e+00, -9.1456e+01, -8.8789e+01, ..., 2.9239e+02,\n"," 5.9588e+02, -6.7169e+01],\n"," [ 9.6391e+01, -2.8461e+02, -8.9830e+01, ..., 2.4198e+02,\n"," 1.9330e+02, -2.7285e+02],\n"," ...,\n"," [-9.6388e+01, -1.8851e+02, -3.9614e+02, ..., 3.3607e+02,\n"," 9.8614e+01, -2.6860e+02],\n"," [-8.1585e+01, -1.1838e+02, -1.9467e+02, ..., 3.6153e+02,\n"," -1.0657e+02, -3.8279e+02],\n"," [ 1.8823e+02, -3.1072e+02, -2.3838e+02, ..., -3.6757e-01,\n"," -5.7795e+01, -1.2785e+02]])\n","ref: tensor([[ 1.9180e+02, -3.3436e+02, -2.5628e+02, ..., -1.2209e+02,\n"," 1.2449e+02, 8.9544e+00],\n"," [ 5.4243e+00, -9.1456e+01, -8.8789e+01, ..., 2.9239e+02,\n"," 5.9588e+02, -6.7170e+01],\n"," [ 9.6391e+01, -2.8461e+02, -8.9830e+01, ..., 2.4198e+02,\n"," 1.9330e+02, -2.7285e+02],\n"," ...,\n"," [-9.6388e+01, -1.8851e+02, -3.9614e+02, ..., 3.3607e+02,\n"," 9.8614e+01, -2.6860e+02],\n"," [-8.1585e+01, -1.1838e+02, -1.9467e+02, ..., 3.6153e+02,\n"," -1.0657e+02, -3.8279e+02],\n"," [ 1.8823e+02, -3.1072e+02, -2.3838e+02, ..., -3.6760e-01,\n"," -5.7795e+01, -1.2785e+02]])\n"]}],"source":["c_ours = torch.from_numpy(c.numpy())\n","print(\"ours:\", c_ours)\n","print(\"ref: \", c_ref)\n","assert torch.allclose(c_ours, c_ref, rtol=1e-03, atol=1e-03)"]}],"metadata":{"kernelspec":{"display_name":"eda","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.18"}},"nbformat":4,"nbformat_minor":2} diff --git a/impl/torch/functions/functions_sparse.cpp b/impl/torch/functions/functions_sparse.cpp index 418791703..bfdfabb3f 100644 --- a/impl/torch/functions/functions_sparse.cpp +++ 
b/impl/torch/functions/functions_sparse.cpp @@ -51,5 +51,29 @@ diopiError_t diopiSpMM(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC return diopiErrorOccurred; } + +extern "C" diopiError_t diopiSpConv(diopiContextHandle_t ctx, diopiTensorHandle_t out_feat, diopiTensorHandle_t in_feat, + diopicTensorHandle_t kernel, diopiTensorHandle_t neighbor_map,const int sum_nnz, + diopiTensorHandle_t neighbor_address, diopiTensorHandle_t q_neighbor_address, const int output_size, + const int qsum_nnz, const bool transpose, const bool allow_tf32, const bool allow_fp16 ) { + + impl::aten::setCurStream(ctx); + + auto atIn_feat = impl::aten::buildATen(in_feat); + auto atOut_feat = impl::aten::buildATen(out_feat); + auto atKernel = impl::aten::buildATen(kernel); + auto atNeighbor_map = impl::aten::buildATen(neighbor_map); + auto atNeighbor_address = impl::aten::buildATen(neighbor_address); + auto atQ_neighbor_address = impl::aten::buildATen(q_neighbor_address); + atOut_feat.zero_(); + + sparse::ops::conv_forward_fetch_on_demand_cuda(atIn_feat, atOut_feat, atKernel, atNeighbor_map, sum_nnz, + atNeighbor_address,atQ_neighbor_address,output_size, + qsum_nnz,transpose,allow_tf32,allow_fp16); + + return diopiSuccess; + +} + } // namespace cuda } // namespace impl diff --git a/impl/torch/functions/functions_sparse/spconv.cu b/impl/torch/functions/functions_sparse/spconv.cu new file mode 100644 index 000000000..3480b20e1 --- /dev/null +++ b/impl/torch/functions/functions_sparse/spconv.cu @@ -0,0 +1,2266 @@ +/* +Please consider citing the following paper when using the code: + +@inproceedings{hong2023pcengine, + title={{Exploiting Hardware Utilization and Adaptive Dataflow for Efficient Sparse Convolution in 3D Point Clouds}}, + author={Hong, Ke and Yu, Zhongming and Dai, Guohao and Yang, Xinhao and Lian, Yaoxiu and Liu, Zehao and Xu, Ningyi and Wang, Yu}, + booktitle={Sixth Conference on Machine Learning and Systems (MLSys)}, + year={2023} +} +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#if __CUDA_ARCH__ >= 700 +#include +#endif + + +// #include "convolution_forward_fetch_on_demand_cuda.h" + +namespace sparse{ +namespace ops { +#define DIV_UP(x, y) ((x) + (y) - 1) / (y) + +// kernels employed in PCEngine [Fetch-on-Demand] +// device function to indicate the weight index in fetch-on-demand gemms +__device__ __forceinline__ int binary_search( + const int *S_csrRowPtr, const int eid, + const int start, const int end) { + + int lo = start, hi = end; + if (lo == hi){ + return lo; + } + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (__ldg(S_csrRowPtr + mid) <= eid) { + lo = mid + 1; + } else { + hi = mid; + } + } + if (__ldg(S_csrRowPtr + hi) <= eid) { + return hi; + } else { + return hi - 1; + } +} + + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_fp32( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const float *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
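+  // cx selects 4 consecutive output channels per thread (ctx = tx * 4, matching the
+  // float4 accesses below). y is this thread's row in the padded GEMM; subtracting
+  // qkpos[widx] and adding kpos[widx] (taken here to be the block-aligned and raw
+  // prefix sums of nnz per weight offset) maps it back to an index into imap/omap.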
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + float Csub[N_LOOP][4] = {0.0f}; + float padding[4] = {0.0f}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float4*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float4*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float4*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float4*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float4*)(&in_f[c_in * in_row + s + ctx])) : + *((float4*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + float Ast = As[n][ty][k]; +#pragma unroll + for (int c = 0; c < 4; c++){ + Csub[n][c] += Ast * Bs[k][ctx + c]; + } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], Csub[n][c]); + } + } + } +} + + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_no_fusion_fp32( + const int knnz, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + + // Weight index + // const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + // const float *kw_ptr = &kw[widx * c_in * c_out]; + const float *kw_ptr = &kw[0]; + + // Coordinate. x is for rows, y is for columns. 
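+  // Unfused variant: a single weight offset is handled per launch, so y indexes
+  // imap/omap directly (no kpos/qkpos remapping) and the epilogue accumulates into
+  // out_f without atomics.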
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty; + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + float Csub[N_LOOP][4] = {0.0f}; + float padding[4] = {0.0f}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float4*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float4*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float4*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < knnz ? imap[y_temp] : -1; + + *((float4*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float4*)(&in_f[c_in * in_row + s + ctx])) : + *((float4*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + float Ast = As[n][ty][k]; +#pragma unroll + for (int c = 0; c < 4; c++){ + Csub[n][c] += Ast * Bs[k][ctx + c]; + } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < knnz ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + // atomicAdd(&out_f[c_out * out_row + cx + c], Csub[n][c]); + out_f[c_out * out_row + cx + c] += Csub[n][c]; + } + } + } +} + + +/* +BLOCK_SIZE = 16, N_LOOP = 8, SKEW = 8, blockDim.x = 4, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_fp32_once( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const float *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
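+  // "Loop once" variant: c_in is assumed to fit within one BLOCK_SIZE tile, so the
+  // loop over input-channel tiles is dropped (s == 0) and the write-back is fused
+  // into the compute loop.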
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + float Csub[N_LOOP][4] = {0.0f}; + float padding[4] = {0.0f}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + // In "loop once" version, s = 0 + // for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float4*)(&Bs[ty][ctx])) = ((ty) < c_in && cx < c_out) ? + *((float4*)(kw_ptr + c_out * (ty) + cx)) : + *((float4*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float4*)(&As[n][ty][ctx])) = ((ctx) < c_in && in_row > -1) ? + *((float4*)(&in_f[c_in * in_row + ctx])) : + *((float4*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < c_in; ++k) { + float Ast = As[n][ty][k]; +#pragma unroll + for (int c = 0; c < 4; c++){ + Csub[n][c] += Ast * Bs[k][ctx + c]; + } + } + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], Csub[n][c]); + } + } + } +} + +/* +BLOCK_SIZE = 16, N_LOOP = 8, SKEW = 8, +blockDim.x = 8, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_fp32_2( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 1; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const float *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + float Csub[N_LOOP][2] = {0.0f}; + float padding[2] = {0.0f}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float2*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float2*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float2*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float2*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float2*)(&in_f[c_in * in_row + s + ctx])) : + *((float2*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + float Ast = As[n][ty][k]; +#pragma unroll + for (int c = 0; c < 2; c++){ + Csub[n][c] += Ast * Bs[k][ctx + c]; + } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 2; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], Csub[n][c]); + } + } + } +} + + +/* +BLOCK_SIZE = 16, N_LOOP = 4, SKEW = 8, +blockDim.x = 16, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_fp32_1( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + // const int ctx = tx; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const float *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + tx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + float Csub[N_LOOP] = {0.0f}; + float padding = 0.0f; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + Bs[ty][tx] = ((s + ty) < c_in && cx < c_out) ? + *(kw_ptr + c_out * (s + ty) + cx) : + padding; + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + As[n][ty][tx] = ((s + tx) < c_in && in_row > -1) ? + in_f[c_in * in_row + s + tx] : + padding; + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + // float Ast = As[n][ty][k]; + // for (int c = 0; c < 2; c++){ + Csub[n] += As[n][ty][k] * Bs[k][tx]; + // } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ + // for (int c = 0; c < 2; c++){ + atomicAdd(&out_f[c_out * out_row + cx], Csub[n]); + // } + } + } +} + + +/* +BLOCK_SIZE = 16, N_LOOP = 4, SKEW = 8, +blockDim.x = 16, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_no_fusion_fp32_1( + const int knnz, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + // const int ctx = tx; + + // Weight index + // const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const float *kw_ptr = &kw[0]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + tx; + const int y = BLOCK_SIZE * N_LOOP * by + ty; + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + float Csub[N_LOOP] = {0.0f}; + float padding = 0.0f; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + Bs[ty][tx] = ((s + ty) < c_in && cx < c_out) ? + *(kw_ptr + c_out * (s + ty) + cx) : + padding; + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < knnz ? imap[y_temp] : -1; + + As[n][ty][tx] = ((s + tx) < c_in && in_row > -1) ? + in_f[c_in * in_row + s + tx] : + padding; + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + // float Ast = As[n][ty][k]; + // for (int c = 0; c < 2; c++){ + Csub[n] += As[n][ty][k] * Bs[k][tx]; + // } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < knnz ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ + // for (int c = 0; c < 2; c++){ + // atomicAdd(&out_f[c_out * out_row + cx], Csub[n]); + out_f[c_out * out_row + cx] += Csub[n]; + // } + } + } +} + + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_fp16_4_once( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { +# if __CUDA_ARCH__ >= 700 + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const half *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + half Csub[N_LOOP][4] = {__float2half(0.0f)}; + half padding[4] = {__float2half(0.0f)}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + // In "loop once" version, s = 0 + // for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Kernel weight to Bs + *((float2*)(&Bs[ty][ctx])) = (ty < c_in && cx < c_out) ? + *((float2*)(kw_ptr + c_out * ty + cx)) : + *((float2*)(&padding[0])); + + int y_temp = y; + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float2*)(&As[n][ty][ctx])) = (ctx < c_in && in_row > -1) ? + *((float2*)(&in_f[c_in * in_row + ctx])) : + *((float2*)(&padding[0])); + + y_temp += BLOCK_SIZE; + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < c_in; ++k){ + half Ast = As[n][ty][k]; +#pragma unroll + for (int c = 0; c < 4; c++){ + Csub[n][c] = __hfma(Ast, Bs[k][ctx + c], Csub[n][c]); + } + } + + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], Csub[n][c]); + } + } + } +#else + #pragma message("FP16 kernels will not be compiled.") +#endif +} + + +/* +BLOCK_SIZE = 16, N_LOOP = 8, SKEW = 8, blockDim.x = 8, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_fp16_2( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { +# if __CUDA_ARCH__ >= 700 + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 1; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const half *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + half Csub[N_LOOP][2] = {__float2half(0.0f)}; + half padding[2] = {__float2half(0.0f)}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float*)(&in_f[c_in * in_row + s + ctx])) : + *((float*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + half Ast = As[n][ty][k]; +#pragma unroll + for (int c = 0; c < 2; c++){ + Csub[n][c] = __hfma(Ast, Bs[k][ctx + c], Csub[n][c]); + } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 2; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], Csub[n][c]); + } + } + } +#else + #pragma message("TF32 kernels will not be compiled.") +#endif +} + + +/* +BLOCK_SIZE = 16, N_LOOP = 4, SKEW = 8, blockDim.x = 16, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_fp16_1( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { +# if __CUDA_ARCH__ >= 700 + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + // const int ctx = tx << 1; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const half *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + tx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + half Csub[N_LOOP] = {__float2half(0.0f)}; + half padding = __float2half(0.0f); + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + Bs[ty][tx] = ((s + ty) < c_in && cx < c_out) ? + *(kw_ptr + c_out * (s + ty) + cx) : + padding; + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + As[n][ty][tx] = ((s + tx) < c_in && in_row > -1) ? + in_f[c_in * in_row + s + tx] : + padding; + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + // half Ast = As[n][ty][k]; + // for (int c = 0; c < 2; c++){ + Csub[n] = __hfma(As[n][ty][k], Bs[k][tx], Csub[n]); + // } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ + // for (int c = 0; c < 2; c++){ + atomicAdd(&out_f[c_out * out_row + cx], Csub[n]); + // } + } + } +#else + #pragma message("FP16 kernels will not be compiled.") +#endif +} + + +/* +BLOCK_SIZE = 16, N_LOOP = 4, SKEW = 8, blockDim.x = 16, blockDim.y = 16 +*/ +template +__global__ void fetch_on_demand_gemm_no_fusion_fp16_1( + const int knnz, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { + +#if __CUDA_ARCH__ >= 700 + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + // const int ctx = tx << 1; + + // Weight index + // const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const half *kw_ptr = &kw[0]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + tx; + const int y = BLOCK_SIZE * N_LOOP * by + ty; + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + half Csub[N_LOOP] = {__float2half(0.0f)}; + half padding = __float2half(0.0f); + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + Bs[ty][tx] = ((s + ty) < c_in && cx < c_out) ? + *(kw_ptr + c_out * (s + ty) + cx) : + padding; + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < knnz ? imap[y_temp] : -1; + + As[n][ty][tx] = ((s + tx) < c_in && in_row > -1) ? + in_f[c_in * in_row + s + tx] : + padding; + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together; + // each thread computes one element + // of the block sub-matrix +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; ++k) { + // half Ast = As[n][ty][k]; + // for (int c = 0; c < 2; c++){ + Csub[n] = __hfma(As[n][ty][k], Bs[k][tx], Csub[n]); + // } + } + } + + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < knnz ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ + // for (int c = 0; c < 2; c++){ + // atomicAdd(&out_f[c_out * out_row + cx], Csub[n]); + out_f[c_out * out_row + cx] = + __hadd(out_f[c_out * out_row + cx], Csub[n]); + // } + } + } +#else + #pragma message("FP16 kernels will not be compiled.") +#endif +} + + +// kernels using tensor cores +// using namespace nvcuda; +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, M = 16, K = 8, N = 16, +MS = 2, NS = 2, WS = 4 = MS x NS +blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_tf32( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { +#if __CUDA_ARCH__ >= 800 + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + const int tid = ty * blockDim.x + tx; + + // Warp index + const int warpId = tid / 32; + const int laneId = tid % 32; + const int warp_row = warpId / NS; + const int warp_col = warpId % NS; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const float *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
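+  // Tensor-core variant: the multiply is done per warp on M x N tf32 MMA fragments;
+  // cx/y below are per-thread coordinates used only for the gmem<->shmem copies and
+  // for the final write-back out of the reused As buffer.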
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + // float Csub[N_LOOP][4] = {0.0f}; + float padding[4] = {0.0f}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memeory array Cs used to + // store the sub-matrix of C + // __shared__ float Cs[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Fragments to store As, Bs and Cs + nvcuda::wmma::fragment c[N_LOOP / 2]; + +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::fill_fragment(c[n], 0.0f); + } + + // May not be necessary + __syncthreads(); + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float4*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float4*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float4*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float4*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float4*)(&in_f[c_in * in_row + s + ctx])) : + *((float4*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together using Tensor Core + // Load data from shmem to tensor core + // Just load Bs once +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; k += K){ + nvcuda::wmma::fragment a[N_LOOP / 2]; + nvcuda::wmma::fragment b; + nvcuda::wmma::load_matrix_sync(b, &Bs[k][warp_col * N], BLOCK_SIZE + SKEW); +#pragma unroll + for (int t = 0; t < b.num_elements; t++) { + b.x[t] = nvcuda::wmma::__float_to_tf32(b.x[t]); + } +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::load_matrix_sync(a[n], &As[n * MS + warpId / WS][warp_row % MS * M][k], BLOCK_SIZE + SKEW); +#pragma unroll + for (int t = 0; t < a[n].num_elements; t++) { + a[n].x[t] = nvcuda::wmma::__float_to_tf32(a[n].x[t]); + } + nvcuda::wmma::mma_sync(c[n], a[n], b, c[n]); + } + } + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Store C fragments to shared memory + // Note that we reuse As for Cs storing +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::store_matrix_sync(&As[n * MS + warpId / WS][warp_row % MS * M][warp_col * N], + c[n], BLOCK_SIZE + SKEW, nvcuda::wmma::mem_row_major); + } + + // Synchronize to make sure that all C fragments are + // stored into shared memory + __syncthreads(); + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? 
omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], As[n][ty][ctx + c]); + } + } + } +#else + #pragma message("TF32 kernels will not be compiled.") +#endif +} + + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, M = 16, K = 8, N = 16, +MS = 2, NS = 2, WS = 4 = MS x NS +blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_no_fusion_tf32( + const int knnz, + const int c_in, + const int c_out, + const float *__restrict__ in_f, + const float *__restrict__ kw, + float *out_f, + const int *imap, + const int *omap) { +#if __CUDA_ARCH__ >= 800 + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + const int tid = ty * blockDim.x + tx; + + // Warp index + const int warpId = tid / 32; + const int laneId = tid % 32; + const int warp_row = warpId / NS; + const int warp_col = warpId % NS; + + // Weight index + // const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + // const float *kw_ptr = &kw[widx * c_in * c_out]; + const float *kw_ptr = &kw[0]; + + // Coordinate. x is for rows, y is for columns. + const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty; + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + // float Csub[N_LOOP][4] = {0.0f}; + float padding[4] = {0.0f}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ float As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memeory array Cs used to + // store the sub-matrix of C + // __shared__ float Cs[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Fragments to store As, Bs and Cs + nvcuda::wmma::fragment c[N_LOOP / 2]; + +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::fill_fragment(c[n], 0.0f); + } + + // May not be necessary + __syncthreads(); + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float4*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float4*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float4*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < knnz ? imap[y_temp] : -1; + + *((float4*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? 
+ *((float4*)(&in_f[c_in * in_row + s + ctx])) : + *((float4*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together using Tensor Core + // Load data from shmem to tensor core + // Just load Bs once +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; k += K){ + nvcuda::wmma::fragment a[N_LOOP / 2]; + nvcuda::wmma::fragment b; + nvcuda::wmma::load_matrix_sync(b, &Bs[k][warp_col * N], BLOCK_SIZE + SKEW); +#pragma unroll + for (int t = 0; t < b.num_elements; t++) { + b.x[t] = nvcuda::wmma::__float_to_tf32(b.x[t]); + } +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::load_matrix_sync(a[n], &As[n * MS + warpId / WS][warp_row % MS * M][k], BLOCK_SIZE + SKEW); +#pragma unroll + for (int t = 0; t < a[n].num_elements; t++) { + a[n].x[t] = nvcuda::wmma::__float_to_tf32(a[n].x[t]); + } + nvcuda::wmma::mma_sync(c[n], a[n], b, c[n]); + } + } + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Store C fragments to shared memory + // Note that we reuse As for Cs storing +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::store_matrix_sync(&As[n * MS + warpId / WS][warp_row % MS * M][warp_col * N], + c[n], BLOCK_SIZE + SKEW, nvcuda::wmma::mem_row_major); + } + + // Synchronize to make sure that all C fragments are + // stored into shared memory + __syncthreads(); + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < knnz ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + // atomicAdd(&out_f[c_out * out_row + cx + c], As[n][ty][ctx + c]); + out_f[c_out * out_row + cx + c] += As[n][ty][ctx + c]; + } + } + } +#else + #pragma message("TF32 kernels will not be compiled.") +#endif +} +////////////////////////////// CUDA_ARCH >= 800 for TF32 /////////////////////////////////// + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, M = 16, K = 16, N = 16, +MS = 2, NS = 2, WS = 4 = MS x NS +blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_fp16_tc4( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { +#if __CUDA_ARCH__ >= 700 + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + const int tid = ty * blockDim.x + tx; + + // Warp index + const int warpId = tid / 32; + // const int laneId = tid % 32; + const int warp_row = warpId / NS; + const int warp_col = warpId % NS; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const half *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. 
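+  // fp16 tensor-core variant: ctx spans 4 half elements per thread, copied as one
+  // 8-byte float2; the MMA itself uses 16x16x16 half fragments per warp.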
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + // float Csub[N_LOOP][4] = {0.0f}; + half padding[4] = {__float2half(0.0f)}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memeory array Cs used to + // store the sub-matrix of C + // __shared__ float Cs[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Fragments to store As, Bs and Cs + nvcuda::wmma::fragment c[N_LOOP / 2]; + +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::fill_fragment(c[n], __float2half(0.0f)); + } + + // May not be necessary + __syncthreads(); + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float2*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float2*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float2*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + *((float2*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float2*)(&in_f[c_in * in_row + s + ctx])) : + *((float2*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together using Tensor Core + // Load data from shmem to tensor core + // Just load Bs once +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; k += K){ + nvcuda::wmma::fragment a[N_LOOP / 2]; + nvcuda::wmma::fragment b; + nvcuda::wmma::load_matrix_sync(b, &Bs[k][warp_col * N], BLOCK_SIZE + SKEW); +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::load_matrix_sync(a[n], &As[n * MS + warpId / WS][warp_row % MS * M][k], BLOCK_SIZE + SKEW); + nvcuda::wmma::mma_sync(c[n], a[n], b, c[n]); + } + } + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Store C fragments to shared memory + // Note that we reuse As for Cs storing +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::store_matrix_sync(&As[n * MS + warpId / WS][warp_row % MS * M][warp_col * N], + c[n], BLOCK_SIZE + SKEW, nvcuda::wmma::mem_row_major); + } + + // Synchronize to make sure that all C fragments are + // stored into shared memory + __syncthreads(); + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? 
omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], As[n][ty][ctx + c]); + } + } + } +#else + #pragma message("FP16 kernels will not be compiled.") +#endif +} + + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, M = 16, K = 16, N = 16, +MS = 2, NS = 2, WS = 4 = MS x NS +blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_fp16_tc4_async( + const int *__restrict__ kpos, + const int *__restrict__ qkpos, + const int k_vol, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { +#if __CUDA_ARCH__ >= 700 + + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + const int tid = ty * blockDim.x + tx; + + // Warp index + const int warpId = tid / 32; + // const int laneId = tid % 32; + const int warp_row = warpId / NS; + const int warp_col = warpId % NS; + + // Weight index + const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + const half *kw_ptr = &kw[widx * c_in * c_out]; + + // Coordinate. x is for rows, y is for columns. + const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty + - __ldg(&qkpos[widx]) + __ldg(&kpos[widx]); + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + // float Csub[N_LOOP][4] = {0.0f}; + half padding[4] = {__float2half(0.0f)}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memeory array Cs used to + // store the sub-matrix of C + // __shared__ float Cs[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Pipelined copy between gmem and shmem + cuda::pipeline pipe = cuda::make_pipeline(); + const auto shape4 = cuda::aligned_size_t(sizeof(float2)); + + // Fragments to store As, Bs and Cs + nvcuda::wmma::fragment c[N_LOOP / 2]; + +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::fill_fragment(c[n], __float2half(0.0f)); + } + + // May not be necessary + __syncthreads(); + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + // const half *kw2Bs_ptr = ((s + ty) < c_in && cx < c_out) ? + // kw_ptr + c_out * (s + ty) + cx : &padding[0]; + pipe.producer_acquire(); + if ((s + ty) < c_in && cx < c_out){ + cuda::memcpy_async(&Bs[ty][ctx], kw_ptr + c_out * (s + ty) + cx, shape4, pipe); + } + else{ + cuda::memcpy_async(&Bs[ty][ctx], &padding[0], shape4, pipe); + } + // cuda::memcpy_async(&Bs[ty][ctx], kw2Bs_ptr, shape4, pipe); + pipe.producer_commit(); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < __ldg(&kpos[widx + 1]) ? imap[y_temp] : -1; + + // const half *inf2As_ptr = ((s + ctx) < c_in && in_row > -1) ? 
+ // &in_f[c_in * in_row + s + ctx] : &padding[0]; + pipe.producer_acquire(); + if ((s + ctx) < c_in && in_row > -1){ + cuda::memcpy_async(&As[n][ty][ctx], &in_f[c_in * in_row + s + ctx], shape4, pipe); + } + else{ + cuda::memcpy_async(&As[n][ty][ctx], &padding[0], shape4, pipe); + } + // cuda::memcpy_async(&As[n][ty][ctx], inf2As_ptr, shape4, pipe); + pipe.producer_commit(); + } + + // Synchronize to make sure the matrices are loaded + cuda::pipeline_consumer_wait_prior<0>(pipe); + __syncthreads(); + + // Multiply the two matrices together using Tensor Core + // Load data from shmem to tensor core + // Just load Bs once +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; k += K){ + nvcuda::wmma::fragment a[N_LOOP / 2]; + nvcuda::wmma::fragment b; + nvcuda::wmma::load_matrix_sync(b, &Bs[k][warp_col * N], BLOCK_SIZE + SKEW); +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::load_matrix_sync(a[n], &As[n * MS + warpId / WS][warp_row % MS * M][k], BLOCK_SIZE + SKEW); + } +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::mma_sync(c[n], a[n], b, c[n]); + } + } + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + pipe.consumer_release(); + __syncthreads(); + } + + // Store C fragments to shared memory + // Note that we reuse As for Cs storing +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::store_matrix_sync(&As[n * MS + warpId / WS][warp_row % MS * M][warp_col * N], + c[n], BLOCK_SIZE + SKEW, nvcuda::wmma::mem_row_major); + } + + // Synchronize to make sure that all C fragments are + // stored into shared memory + __syncthreads(); + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < __ldg(&kpos[widx + 1]) ? omap[y_temp] : -1; + if (out_row > -1 && cx < c_out){ +#pragma unroll + for (int c = 0; c < 4; c++){ + atomicAdd(&out_f[c_out * out_row + cx + c], As[n][ty][ctx + c]); + } + } + } +#else + #pragma message("FP16 kernels with asynchronous copy will not be compiled.") +#endif +} + + +/* +BLOCK_SIZE = 32, N_LOOP = 4, SKEW = 8, M = 16, K = 16, N = 16, +MS = 2, NS = 2, WS = 4 = MS x NS +blockDim.x = 8, blockDim.y = 32 +*/ +template +__global__ void fetch_on_demand_gemm_no_fusion_fp16( + const int knnz, + const int c_in, + const int c_out, + const half *__restrict__ in_f, + const half *__restrict__ kw, + half *out_f, + const int *imap, + const int *omap) { +#if __CUDA_ARCH__ >= 700 + // Block index + const int bx = blockIdx.x; + const int by = blockIdx.y; + + // Thread index + const int tx = threadIdx.x; + const int ty = threadIdx.y; + const int ctx = tx << 2; + const int tid = ty * blockDim.x + tx; + + // Warp index + const int warpId = tid / 32; + // const int laneId = tid % 32; + const int warp_row = warpId / NS; + const int warp_col = warpId % NS; + + // Weight index + // const int widx = binary_search(qkpos, by * N_LOOP * BLOCK_SIZE, 0, k_vol); + // const half *kw_ptr = &kw[widx * c_in * c_out]; + const half *kw_ptr = &kw[0]; + + // Coordinate. x is for rows, y is for columns. 
+ const int cx = BLOCK_SIZE * bx + ctx; + const int y = BLOCK_SIZE * N_LOOP * by + ty; + + // Csub is used to store the element of the block sub-matrix + // that is computed by the thread + // float Csub[N_LOOP][4] = {0.0f}; + half padding[4] = {__float2half(0.0f)}; + + // Declaration of the shared memory array As used to + // store the sub-matrix of A + __shared__ half As[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memory array Bs used to + // store the sub-matrix of B + __shared__ half Bs[BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Declaration of the shared memeory array Cs used to + // store the sub-matrix of C + // __shared__ float Cs[N_LOOP][BLOCK_SIZE][BLOCK_SIZE + SKEW]; + + // Fragments to store As, Bs and Cs + nvcuda::wmma::fragment c[N_LOOP / 2]; + +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::fill_fragment(c[n], __float2half(0.0f)); + } + + // May not be necessary + __syncthreads(); + + // Loop over all the sub-matrices of A and B + // required to compute the block sub-matrix + for (int s = 0; s < c_in; s += BLOCK_SIZE) { + // Load the matrices from device memory + // to shared memory; each thread loads + // one element of each matrix + + // Kernel weight to Bs + *((float2*)(&Bs[ty][ctx])) = ((s + ty) < c_in && cx < c_out) ? + *((float2*)(kw_ptr + c_out * (s + ty) + cx)) : + *((float2*)(&padding[0])); + + // Input feature to As + for (int n = 0; n < N_LOOP; n++){ + + int y_temp = y + n * BLOCK_SIZE; + + // The thread deals with the x-th channel of the y-th output + int in_row = y_temp < knnz ? imap[y_temp] : -1; + + *((float2*)(&As[n][ty][ctx])) = ((s + ctx) < c_in && in_row > -1) ? + *((float2*)(&in_f[c_in * in_row + s + ctx])) : + *((float2*)(&padding[0])); + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Multiply the two matrices together using Tensor Core + // Load data from shmem to tensor core + // Just load Bs once +#pragma unroll + for (int k = 0; k < BLOCK_SIZE; k += K){ + nvcuda::wmma::fragment a[N_LOOP / 2]; + nvcuda::wmma::fragment b; + nvcuda::wmma::load_matrix_sync(b, &Bs[k][warp_col * N], BLOCK_SIZE + SKEW); +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::load_matrix_sync(a[n], &As[n * MS + warpId / WS][warp_row % MS * M][k], BLOCK_SIZE + SKEW); + nvcuda::wmma::mma_sync(c[n], a[n], b, c[n]); + } + } + // Synchronize to make sure that the preceding + // computation is done before loading two new + // sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Store C fragments to shared memory + // Note that we reuse As for Cs storing +#pragma unroll + for (int n = 0; n < N_LOOP / 2; n++){ + nvcuda::wmma::store_matrix_sync(&As[n * MS + warpId / WS][warp_row % MS * M][warp_col * N], + c[n], BLOCK_SIZE + SKEW, nvcuda::wmma::mem_row_major); + } + + // Synchronize to make sure that all C fragments are + // stored into shared memory + __syncthreads(); + + // Write the block sub-matrix to device memory; + // each thread writes one element +#pragma unroll + for (int n = 0; n < N_LOOP; n++){ + int y_temp = y + n * BLOCK_SIZE; + int out_row = y_temp < knnz ? 
omap[y_temp] : -1;
+        if (out_row > -1 && cx < c_out){
+#pragma unroll
+            for (int c = 0; c < 4; c++){
+                // out_f[c_out * out_row + cx + c] += As[n][ty][ctx + c];
+                out_f[c_out * out_row + cx + c] =
+                    __hadd(out_f[c_out * out_row + cx + c], As[n][ty][ctx + c]);
+            }
+        }
+    }
+#else
+    #pragma message("FP16 kernels will not be compiled.")
+#endif
+}
+///////////////////////////////// CUDA_ARCH >= 700 ///////////////////////////////////
+
+
+
+
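For reference, every kernel above implements the same gather-GEMM-scatter over the neighbor map: for each kernel offset w, the input rows listed in imap are gathered, multiplied by the (c_in, c_out) weight slice of w, and accumulated into the output rows listed in omap. The plain-C++ sketch below is not part of the patch; the name spconv_reference is illustrative, and it is only meant as a readable specification for checking small cases against the CUDA output.

#include <vector>

// in_f : [n_in x c_in] row-major input features
// kw   : [k_vol x c_in x c_out] kernel weights
// imap / omap : per-mapping input / output row indices (the two halves of neighbor_map)
// kpos : [k_vol + 1] prefix sums of the per-weight mapping counts (neighbor_address)
// out_f: [n_out x c_out], zero-initialized by the caller
void spconv_reference(const std::vector<float>& in_f, const std::vector<float>& kw,
                      const std::vector<int>& imap, const std::vector<int>& omap,
                      const std::vector<int>& kpos, int k_vol, int c_in, int c_out,
                      std::vector<float>& out_f) {
    for (int w = 0; w < k_vol; ++w) {                  // one GEMM per kernel offset
        for (int m = kpos[w]; m < kpos[w + 1]; ++m) {  // mappings handled by weight w
            const float* x = &in_f[imap[m] * c_in];    // gather the input row
            float* y = &out_f[omap[m] * c_out];        // scatter-accumulate into the output row
            for (int j = 0; j < c_out; ++j) {
                float acc = 0.f;
                for (int i = 0; i < c_in; ++i)
                    acc += x[i] * kw[(w * c_in + i) * c_out + j];
                y[j] += acc;
            }
        }
    }
}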
+// in_feat: (N, c) N=# of input points, c = input channels
+// out_feat: (M, o) M=# of output points, o = output channels
+//           for stride=1, M=N. For stride>1, the N input coords
+//           are requantized to M points with grid size (stride * cur_stride)
+// kernel: (k^3, c, o) for a 3D convolution of length k
+// neighbor_map: (a, 2) the hash table query results from in_coords to out_coords
+//               where neighbor_map[:,0] is the index of the input feature
+//               and neighbor_map[:,1] is the index of the output feature
+// neighbor_offset: (k^3) count of active weights based on neighbor_map
+//                  with unused weights having 0 and neighbor_offset[k^3/2]
+//                  holding w[0,0].
+at::Tensor conv_forward_fetch_on_demand_cuda(
+    at::Tensor in_feat, at::Tensor& out_feat, at::Tensor kernel,
+    at::Tensor neighbor_map, const int sum_nnz,
+    at::Tensor neighbor_address, at::Tensor q_neighbor_address,
+    const int output_size, const int qsum_nnz, const bool transpose,
+    const bool allow_tf32, const bool allow_fp16) {
+
+    // int sum_nnz = (int)torch::sum(neighbor_offset).item();
+    int input_size = in_feat.size(0);
+    int in_channel = in_feat.size(1);
+    int out_channel = kernel.size(2);
+    int k_vol = kernel.size(0);
+    // int *knnz_ptr = neighbor_offset.data_ptr();
+    // int *in_map_ptr = in_neighbor_map.data_ptr();
+    // int *out_map_ptr = out_neighbor_map.data_ptr();
+    int *kpos_ptr = neighbor_address.data_ptr<int>();
+    int *qkpos_ptr = q_neighbor_address.data_ptr<int>();
+    int *in_map_ptr;
+    int *out_map_ptr;
+    if (transpose){
+        in_map_ptr = neighbor_map.data_ptr<int>() + sum_nnz;
+        out_map_ptr = neighbor_map.data_ptr<int>();
+    }
+    else{
+        in_map_ptr = neighbor_map.data_ptr<int>();
+        out_map_ptr = neighbor_map.data_ptr<int>() + sum_nnz;
+    }
+
+    // memory allocation
+    // at::Tensor out_feat = torch::zeros({output_size, out_channel},
+    //     at::device(in_feat.device()).dtype(in_feat.scalar_type()));
+    // at::Tensor kpos = torch::zeros({k_vol + 1},
+    //     at::device(in_feat.device()).dtype(at::ScalarType::Int));
+    // at::Tensor qkpos = torch::zeros({k_vol + 1},
+    //     at::device(in_feat.device()).dtype(at::ScalarType::Int));
+    // int *kpos_ptr = kpos.data_ptr();
+    // int *qkpos_ptr = qkpos.data_ptr();
+
+    // should be modified in the future
+    int mid_kernel = (k_vol % 2 == 1) ? k_vol / 2 : 0;
+
+    bool data_type_half = in_feat.scalar_type() == at::ScalarType::Half;
+    // bool precompute_mid = (input_size == output_size && k_vol % 2 == 1);
+    bool precompute_mid = false;
+
+    // exclusive_scan_for_kernel_quantified<<<1, k_vol, 0, 0>>>(
+    //     k_vol + 1, knnz_ptr, 128, kpos_ptr, qkpos_ptr
+    // );
+
+    // int qsum_nnz = qkpos[k_vol].item();
+    // printf("%d", qsum_nnz);
+
+    if (data_type_half && allow_fp16){
+        if (in_channel % 4 == 0 && out_channel % 4 == 0){
+            if (in_channel <= 16 || out_channel <= 16){
+                fetch_on_demand_gemm_fp16_4_once<16, 4, 8>
+                    <<>>(
+                    kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                    reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                    reinterpret_cast<half *>(kernel.data_ptr<at::Half>()),
+                    reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                    in_map_ptr, out_map_ptr);
+            }
+            else{
+                if (allow_tf32){
+                    fetch_on_demand_gemm_fp16_tc4_async<32, 4, 8, 16, 16, 16, 4, 2, 2>
+                        <<>>(
+                        kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                        reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                        reinterpret_cast<half *>(kernel.data_ptr<at::Half>()),
+                        reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                        in_map_ptr, out_map_ptr);
+                }
+                else{
+                    fetch_on_demand_gemm_fp16_tc4<32, 4, 8, 16, 16, 16, 4, 2, 2>
+                        <<>>(
+                        kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                        reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                        reinterpret_cast<half *>(kernel.data_ptr<at::Half>()),
+                        reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                        in_map_ptr, out_map_ptr);
+                }
+            }
+        }
+        else if (in_channel % 2 == 0 && out_channel % 2 == 0){
+            fetch_on_demand_gemm_fp16_2<16, 8, 8>
+                <<>>(
+                kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                reinterpret_cast<half *>(kernel.data_ptr<at::Half>()),
+                reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                in_map_ptr, out_map_ptr);
+        }
+        else{
+            fetch_on_demand_gemm_fp16_1<16, 4, 8>
+                <<>>(
+                kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                reinterpret_cast<half *>(kernel.data_ptr<at::Half>()),
+                reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                in_map_ptr, out_map_ptr);
+        }
+    }
+    else{
+        if (in_channel % 4 == 0 && out_channel % 4 == 0){
+            if (in_channel <= 16 && out_channel <= 16){
+                fetch_on_demand_gemm_fp32_once<16, 4, 8>
+                    <<>>(
+                    kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                    in_feat.data_ptr<float>(), kernel.data_ptr<float>(), out_feat.data_ptr<float>(),
+                    in_map_ptr, out_map_ptr);
+            }
+            else{
+                if (allow_tf32){
+                    fetch_on_demand_gemm_tf32<32, 4, 8, 16, 8, 16, 4, 2, 2>
+                        <<>>(
+                        kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                        in_feat.data_ptr<float>(), kernel.data_ptr<float>(), out_feat.data_ptr<float>(),
+                        in_map_ptr, out_map_ptr);
+                }
+                else{
+                    fetch_on_demand_gemm_fp32<32, 4, 8>
+                        <<>>(
+                        kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                        in_feat.data_ptr<float>(), kernel.data_ptr<float>(), out_feat.data_ptr<float>(),
+                        in_map_ptr, out_map_ptr);
+                }
+            }
+        }
+        else if (in_channel % 2 == 0 && out_channel % 2 == 0){
+            fetch_on_demand_gemm_fp32_2<16, 8, 8>
+                <<>>(
+                kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                in_feat.data_ptr<float>(), kernel.data_ptr<float>(), out_feat.data_ptr<float>(),
+                in_map_ptr, out_map_ptr);
+        }
+        else{
+            fetch_on_demand_gemm_fp32_1<16, 4, 8>
+                <<>>(
+                kpos_ptr, qkpos_ptr, k_vol, in_channel, out_channel,
+                in_feat.data_ptr<float>(), kernel.data_ptr<float>(), out_feat.data_ptr<float>(),
+                in_map_ptr, out_map_ptr);
+        }
+    }
+
+    // precomputation only for odd channel size
+    // bool precompute_mid = (input_size == output_size && k_vol % 2 == 1);
+    if (precompute_mid){
+        at::addmm_out(out_feat, out_feat, in_feat, kernel[mid_kernel]);
+    }
+
+    return out_feat;
+}
+
+
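The fused path above assumes neighbor_address (kpos) and q_neighbor_address (qkpos) are already computed; the exclusive_scan_for_kernel_quantified launch that would build them on the device is left commented out. The host-side sketch below is an assumption about what those arrays contain, inferred from how the kernels index them (qkpos pads each weight's segment up to a multiple of 128 so that every 128-row block of the grid maps to exactly one weight offset); it is illustrative only and the helper name build_kpos_qkpos is not part of the patch.

#include <vector>

// knnz : per-weight mapping counts (neighbor_offset)
// q    : quantization step, 128 in the commented-out launch above
// kpos : exclusive prefix sums of knnz               -> neighbor_address
// qkpos: exclusive prefix sums padded per weight to q -> q_neighbor_address
void build_kpos_qkpos(const std::vector<int>& knnz, int q,
                      std::vector<int>& kpos, std::vector<int>& qkpos) {
    const int k_vol = static_cast<int>(knnz.size());
    kpos.assign(k_vol + 1, 0);
    qkpos.assign(k_vol + 1, 0);
    for (int k = 0; k < k_vol; ++k) {
        kpos[k + 1] = kpos[k] + knnz[k];
        // round each segment up to a multiple of q so every weight owns whole blocks
        qkpos[k + 1] = qkpos[k] + (knnz[k] + q - 1) / q * q;
    }
    // qkpos[k_vol] plays the role of qsum_nnz in the dispatch code above
}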
+at::Tensor conv_forward_fetch_on_demand_no_fusion_cuda(
+    at::Tensor in_feat, at::Tensor kernel,
+    at::Tensor neighbor_map, at::Tensor neighbor_offset,
+    const int sum_nnz, const int output_size, const bool transpose,
+    const bool allow_tf32, const bool allow_fp16){
+
+    // int sum_nnz = (int)torch::sum(neighbor_offset).item();
+    int input_size = in_feat.size(0);
+    int in_channel = in_feat.size(1);
+    int out_channel = kernel.size(2);
+    int k_vol = kernel.size(0);
+    int *knnz_ptr = neighbor_offset.data_ptr<int>();
+    // int *in_map_ptr = in_neighbor_map.data_ptr();
+    // int *out_map_ptr = out_neighbor_map.data_ptr();
+    // int *kpos_ptr = neighbor_address.data_ptr();
+    // int *qkpos_ptr = q_neighbor_address.data_ptr();
+    int *in_map_ptr;
+    int *out_map_ptr;
+    if (transpose){
+        in_map_ptr = neighbor_map.data_ptr<int>() + sum_nnz;
+        out_map_ptr = neighbor_map.data_ptr<int>();
+    }
+    else{
+        in_map_ptr = neighbor_map.data_ptr<int>();
+        out_map_ptr = neighbor_map.data_ptr<int>() + sum_nnz;
+    }
+
+    // memory allocation
+    at::Tensor out_feat = torch::zeros({output_size, out_channel},
+        at::device(in_feat.device()).dtype(in_feat.scalar_type()));
+    // at::Tensor kpos = torch::zeros({k_vol + 1},
+    //     at::device(in_feat.device()).dtype(at::ScalarType::Int));
+    // at::Tensor qkpos = torch::zeros({k_vol + 1},
+    //     at::device(in_feat.device()).dtype(at::ScalarType::Int));
+    // int *kpos_ptr = kpos.data_ptr();
+    // int *qkpos_ptr = qkpos.data_ptr();
+
+    // should be modified in the future
+    int mid_kernel = (k_vol % 2 == 1) ? k_vol / 2 : 0;
+
+    bool data_type_half = in_feat.scalar_type() == at::ScalarType::Half;
+    // bool precompute_mid = (input_size == output_size && k_vol % 2 == 1);
+    bool precompute_mid = false;
+
+    /********************************************************************/
+    // loop over all kernel offsets
+    int cur_idx = 0;
+    // int stream_id = 0;
+    for (int k = 0; k < k_vol; k++){
+        int cur_nnz = knnz_ptr[k];
+
+        if (cur_nnz == 0){continue;}
+
+        // size_t gridnum_x = DIV_UP(out_channel, 32);
+        // size_t gridnum_y = DIV_UP(cur_nnz, 32);
+
+        if (data_type_half && allow_fp16){
+            if (in_channel % 4 == 0 && out_channel % 4 == 0){
+                fetch_on_demand_gemm_no_fusion_fp16<32, 4, 8, 16, 16, 16, 4, 2, 2>
+                    <<>>(
+                    cur_nnz, in_channel, out_channel,
+                    reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                    reinterpret_cast<half *>(kernel.data_ptr<at::Half>() +
+                        k * in_channel * out_channel),
+                    reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                    &in_map_ptr[cur_idx], &out_map_ptr[cur_idx]
+                );
+            }
+            else{
+                fetch_on_demand_gemm_no_fusion_fp16_1<16, 4, 8>
+                    <<>>(
+                    cur_nnz, in_channel, out_channel,
+                    reinterpret_cast<half *>(in_feat.data_ptr<at::Half>()),
+                    reinterpret_cast<half *>(kernel.data_ptr<at::Half>() +
+                        k * in_channel * out_channel),
+                    reinterpret_cast<half *>(out_feat.data_ptr<at::Half>()),
+                    &in_map_ptr[cur_idx], &out_map_ptr[cur_idx]
+                );
+            }
+        }
+        else{
+            if (in_channel % 4 == 0 && out_channel % 4 == 0){
+                if (allow_tf32){
+                    fetch_on_demand_gemm_no_fusion_tf32<32, 4, 8, 16, 8, 16, 4, 2, 2>
+                        <<>>(
+                        cur_nnz, in_channel, out_channel,
+                        in_feat.data_ptr<float>(),
+                        (kernel.data_ptr<float>() + k * in_channel * out_channel),
+                        out_feat.data_ptr<float>(),
+                        &in_map_ptr[cur_idx], &out_map_ptr[cur_idx]
+                    );
+                }
+                else{
+                    fetch_on_demand_gemm_no_fusion_fp32<32, 4, 8>
+                        <<>>(
+                        cur_nnz, in_channel, out_channel,
+                        in_feat.data_ptr<float>(),
+                        (kernel.data_ptr<float>() + k * in_channel * out_channel),
+                        out_feat.data_ptr<float>(),
+                        &in_map_ptr[cur_idx], &out_map_ptr[cur_idx]
+                    );
+                }
+            }
+            else{
+                fetch_on_demand_gemm_no_fusion_fp32_1<16, 4, 8>
+                    <<>>(
+                    cur_nnz, in_channel, out_channel,
+                    in_feat.data_ptr<float>(),
+                    (kernel.data_ptr<float>() + k * in_channel * out_channel),
+                    out_feat.data_ptr<float>(),
+                    &in_map_ptr[cur_idx], &out_map_ptr[cur_idx]
+                );
+            }
+        }
+
+        cur_idx += cur_nnz;
+    }
+
+    // precomputation only for odd channel size
+    // bool precompute_mid = (input_size == output_size && k_vol % 2 == 1);
+    if (precompute_mid){
+        at::addmm_out(out_feat, out_feat, in_feat, kernel[mid_kernel]);
+    }
+
+    return out_feat;
+}
+
+}
+}
\ No newline at end of file
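A minimal usage sketch of the fused entry point, assuming an ATen caller with float or half CUDA tensors. It preallocates a zero-filled output, mirroring the allocation that is commented out inside conv_forward_fetch_on_demand_cuda; the wrapper name run_fused_spconv and the flag values are illustrative only, not part of the patch.

#include <torch/torch.h>
#include "sparse_kernel.h"  // declares sparse::ops::conv_forward_fetch_on_demand_cuda

at::Tensor run_fused_spconv(at::Tensor in_feat,            // (N, c_in), float32 or float16, CUDA
                            at::Tensor kernel,             // (k_vol, c_in, c_out), same dtype/device
                            at::Tensor neighbor_map,       // int32 CUDA: first sum_nnz entries are input rows,
                                                           // the next sum_nnz entries are output rows
                            at::Tensor neighbor_address,   // (k_vol + 1) int32 CUDA prefix sums (kpos)
                            at::Tensor q_neighbor_address, // (k_vol + 1) int32 CUDA padded prefix sums (qkpos)
                            int sum_nnz, int output_size, int qsum_nnz) {
    // output buffer is provided by the caller and zero-filled before accumulation
    at::Tensor out_feat = torch::zeros({output_size, kernel.size(2)},
                                       at::device(in_feat.device()).dtype(in_feat.scalar_type()));
    return sparse::ops::conv_forward_fetch_on_demand_cuda(
        in_feat, out_feat, kernel, neighbor_map, sum_nnz,
        neighbor_address, q_neighbor_address, output_size, qsum_nnz,
        /*transpose=*/false, /*allow_tf32=*/true, /*allow_fp16=*/true);
}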
diff --git a/impl/torch/sparse_kernel.h b/impl/torch/sparse_kernel.h
index 1642d89b4..98229ff24 100644
--- a/impl/torch/sparse_kernel.h
+++ b/impl/torch/sparse_kernel.h
@@ -8,5 +8,18 @@ namespace ops {
 at::Tensor row_balance_row_major_seq_reduce_kernel(at::Tensor& out, const at::Tensor& row_ptr, const at::Tensor& col_ind,
                                                    const at::Tensor& value, at::Tensor& input);
+at::Tensor conv_forward_fetch_on_demand_cuda(
+    at::Tensor in_feat, at::Tensor& out_feat, at::Tensor kernel,
+    at::Tensor neighbor_map, const int sum_nnz,
+    at::Tensor neighbor_address, at::Tensor q_neighbor_address,
+    const int output_size, const int qsum_nnz, const bool transpose,
+    const bool allow_tf32, const bool allow_fp16);
+
+// at::Tensor conv_forward_fetch_on_demand_no_fusion_cuda(
+//     at::Tensor in_feat, at::Tensor kernel,
+//     at::Tensor neighbor_map, at::Tensor neighbor_offset,
+//     const int sum_nnz, const int output_size, const bool transpose,
+//     const bool allow_tf32, const bool allow_fp16);
+
 
 }  // namespace ops
 }  // namespace sparse
diff --git a/proto/include/diopi/functions_sparse.h b/proto/include/diopi/functions_sparse.h
index 539015f7e..fee6eb2ae 100644
--- a/proto/include/diopi/functions_sparse.h
+++ b/proto/include/diopi/functions_sparse.h
@@ -22,6 +22,34 @@ extern "C" {
  */
 DIOPI_API diopiError_t diopiSpMM(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mat2);
+/**
+ * @brief Fetch-on-demand sparse convolution (SpConv) forward.
+ * @param[in] ctx Context environment.
+ * @param[in] in_feat (N, c) N=# of input points, c = input channels
+ * @param[out] out_feat (M, o) M=# of output points, o = output channels
+ * @param[in] kernel (k^3, c, o) for a 3D convolution of length k
+ * @param[in] neighbor_map (a, 2) the hash table query results from in_coords to out_coords
+ */
+DIOPI_API diopiError_t diopiSpConv(diopiContextHandle_t ctx, diopiTensorHandle_t out_feat, diopiTensorHandle_t in_feat,
+                                   diopiTensorHandle_t kernel, diopiTensorHandle_t neighbor_map, const int sum_nnz,
+                                   diopiTensorHandle_t neighbor_address, diopiTensorHandle_t q_neighbor_address, const int output_size,
+                                   const int qsum_nnz, const bool transpose, const bool allow_tf32, const bool allow_fp16);
+
+
+
+// /**
+//  * @brief Fetch-on-demand SpConv
+//  * @param[in] ctx Context environment.
+//  * @param[in] in_feat (N, c) N=# of input points, c = input channels
+//  * @param[out] out_feat (M, o) M=# of output points, o = output channels
+//  * @param[in] kernel (k^3, c, o) for a 3D convolution of length k
+//  * @param[in] neighbor_map (a, 2) the hash table query results from in_coords to out_coords
+//  */
+// DIOPI_API diopiError_t diopiSpConv(diopiContextHandle_t ctx, diopiTensorHandle_t in_feat, diopiTensorHandle_t out_feat,
+//                                    diopiConstTensorHandle_t kernel, diopiConstTensorHandle_t neighbor_map, const int sum_nnz,
+//                                    diopiTensorHandle_t neighbor_address, diopiTensorHandle_t q_neighbor_address, const int output_size,
+//                                    const int qsum_nnz, const bool transpose, const bool allow_tf32, const bool allow_fp16);
+
 #if defined(__cplusplus)
 }
 #endif  // __cplusplus
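The doxygen block above documents only the tensor arguments; the remaining scalars follow the conventions of the CUDA implementation (sum_nnz mappings stored as two consecutive index halves in neighbor_map, (k_vol + 1)-entry prefix arrays in neighbor_address and q_neighbor_address, output_size output rows). The helper below is a hedged sketch of the argument checks a torch-based backend could run before dispatching; the name check_spconv_args and the use of TORCH_CHECK are assumptions, not part of the DIOPI interface.

#include <ATen/ATen.h>

inline void check_spconv_args(const at::Tensor& in_feat, const at::Tensor& kernel,
                              const at::Tensor& neighbor_map,
                              const at::Tensor& neighbor_address,
                              const at::Tensor& q_neighbor_address,
                              int sum_nnz, int output_size) {
    TORCH_CHECK(kernel.dim() == 3 && kernel.size(1) == in_feat.size(1),
                "kernel must be (k_vol, c_in, c_out) with c_in matching in_feat");
    TORCH_CHECK(neighbor_map.numel() == 2 * sum_nnz,
                "neighbor_map holds sum_nnz input indices followed by sum_nnz output indices");
    TORCH_CHECK(neighbor_address.numel() == kernel.size(0) + 1 &&
                q_neighbor_address.numel() == kernel.size(0) + 1,
                "neighbor_address / q_neighbor_address are (k_vol + 1) prefix arrays");
    TORCH_CHECK(output_size > 0, "output_size is the number of output points (M)");
}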