Skip to content

Commit

Permalink
Add support for building CUDA extension on Windows (#396)
Browse files Browse the repository at this point in the history
* Enable FP6-LLM kernel build on Windows

* fix benchmark script

* update setup.py

* update

* fix indent

* add -t=0 for linux

---------

Co-authored-by: Matthew Douglas <[email protected]>
  • Loading branch information
gau-nernst and matthewdouglas authored Jun 18, 2024
1 parent f5b6ec9 commit d0af941
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 48 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmark_fp6_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@


def benchmark(m: int, k: int, n: int):
fp6_weight = torch.randint(256, size=(n, k // 4 * 3), dtype=torch.uint8, device="cuda")
fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_linear = Fp6LlmLinear(fp6_weight.view(torch.int32), scales)
fp6_linear = Fp6LlmLinear(fp6_weight, scales)

fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight.view(-1), n, k, dtype=torch.half) * scales[:, None]
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
fp6_output = fp6_linear(fp16_act)
Expand Down
50 changes: 36 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def read_version(file_path="version.txt"):
CUDAExtension,
BuildExtension,
CUDA_HOME,
IS_WINDOWS
)


Expand All @@ -52,20 +53,41 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
]
}
if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
if not IS_WINDOWS:
extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])

else:
extra_link_args = []
extra_compile_args = {
"cxx": [
"/O2" if not debug_mode else "/Od",
"/permissive-"
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}

if debug_mode:
extra_compile_args["cxx"].append("/ZI")
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
Expand Down
9 changes: 5 additions & 4 deletions torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh

#include "configs.h"
#include "utils_gmem.cuh"
Expand Down Expand Up @@ -133,11 +133,12 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
// Trible-Buffer for B Tile
half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
// MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR.
half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#ifdef PIPELINE_LEVEL_SMEM
half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#endif
half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
//
bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter;
// Copying A tile from Global to Register, Bypassing L1, using double-buffer
Expand Down
15 changes: 10 additions & 5 deletions torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh

/***************************************************************************
* Copyright 2023 The FLash-LLM Authors. All rights reserved.
Expand All @@ -36,11 +36,14 @@
#include <assert.h>
#include "configs.h"

// MODIFICATION NOTE: to support MSVC
// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4]
// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR)
#ifdef PIPELINE_LEVEL_SMEM
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4],
half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4],
half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
#ifdef DEBUG_MODE
static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
#endif
Expand Down Expand Up @@ -112,8 +115,10 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[
}
#endif

// MODIFICATION NOTE: to support MSVC, the function signature is changed from
// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b).
__device__ __forceinline__ void
MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b)
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
{
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
Expand Down
8 changes: 5 additions & 3 deletions torchao/csrc/cuda/fp6_llm/utils_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh

#ifndef UTILS_CORE_CUH
#define UTILS_CORE_CUH
Expand All @@ -35,12 +35,13 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u
}
}

// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
template <typename TilingConfig>
__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4],
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales)
{
// Writing registers
Expand All @@ -53,13 +54,14 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t (
B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, 0); // Loading B from shared to registers
}

// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
template <typename TilingConfig>
__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16],
uint32_t (*a)[4],
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales,
int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
{
Expand Down
13 changes: 7 additions & 6 deletions torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh

#ifndef UTILS_GMEM_CUH
#define UTILS_GMEM_CUH
Expand Down Expand Up @@ -57,17 +57,18 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc
for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8];
}

// MODIFICATION NOTE: to support MSVC, half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
/*
* (1) Copying X rows * 64 columns of FP16 values, originally in row major
* (2) Copying 64 rows * X columns of FP16 values, originally in column major
* 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads
*/
template<int MaxNumOfLinesToCopy, int BLOCK_WARPS>
__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
// static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
const int NumOfGroups = NumOfThreads / 8;
Expand Down
30 changes: 17 additions & 13 deletions torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh
// To support MSVC, all instances of u_int32_t are changed to uint32_t.

#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH
Expand All @@ -26,7 +27,7 @@
* Outputs: R1, R2
* Note: Simplified Exponent calculation is applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) {
__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) {
*R2 = *R1 & 0x80808080;
*R1 = *R1 >> 2;
*R1 = *R1 & 0x1f1f1f1f;
Expand All @@ -41,7 +42,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2)
* Outputs: R1, R2
* Note: Simplified Exponent calculation is NOT applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) {
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) {
//*R2 = *R1 & 0x80808080;
*R2 = *R1 & 0xc0c0c0c0;
*R1 = *R1 >> 2;
Expand All @@ -63,7 +64,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_
//*R2 = 0x3c003c00;
}

__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) {
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) {
half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
half* FP16_2 = FP16_1 + 1;
uint32_t output;
Expand All @@ -73,16 +74,19 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc
return output;
}

__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4],
u_int32_t __restrict__ *read_RPTR_Frag1,
u_int32_t __restrict__ *read_RPTR_Frag2,
u_int32_t *Scales) {
u_int32_t *OutputRegs = reinterpret_cast<u_int32_t*> (Reg);
u_int32_t *Frag1_PTR = read_RPTR_Frag1;
u_int32_t *Frag2_PTR = read_RPTR_Frag2;
// MODIFICATION NOTE: to support MSVC
// - u_int32_t __restrict__ Reg[][4] is changed to below.
// - u_int32_t __restrict__ *read_RPTR_Frag1 is changed to below. similarly for read_RPTR_Frag2
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4],
uint32_t * __restrict__ read_RPTR_Frag1,
uint32_t * __restrict__ read_RPTR_Frag2,
uint32_t * Scales) {
uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg);
uint32_t *Frag1_PTR = read_RPTR_Frag1;
uint32_t *Frag2_PTR = read_RPTR_Frag2;
half *Scale_RPTR = reinterpret_cast<half*>(Scales);
u_int32_t Packed_FP6 = 0;
u_int32_t tmp = 0;
uint32_t Packed_FP6 = 0;
uint32_t tmp = 0;
// Dequantizing 32 FP6, each Loop dequantizing 4 FP6
#pragma unroll(8)
for(int i=0; i<8; i++) {
Expand Down

0 comments on commit d0af941

Please sign in to comment.