forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MoE BMM INT4 rowwise weight-only (pytorch#3219)
Summary: Pull Request resolved: pytorch#3219 X-link: facebookresearch/FBGEMM#316 Marlin int4 weight-only with loopover for bmm performs well (**up to 7x faster** than bf16 bmm) when dim M is small to medium (e.g., < 256) in decode. For larger dim M, we can use this bmm int4 rowwise weight-only kernel in prefill; it is around **1.5x faster** than the Marlin int4 loopover and maintains the same accuracy. More results can be found in this [data sheet](https://docs.google.com/spreadsheets/d/12JWt3SqX_1GSLKwjGyt0KQl9SMWDF0r0C63MMKsE9JM/edit?usp=sharing). Reviewed By: jianyuh Differential Revision: D63818529 fbshipit-source-id: 127e841fa7c6c1ce810b6e8b6e35907eeaecafd6
- Loading branch information
1 parent
a0966e8
commit d3eae1d
Showing
4 changed files
with
366 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
298 changes: 298 additions & 0 deletions
298
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,298 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <cutlass/util/device_memory.h> | ||
#include <cutlass/util/packed_stride.hpp> | ||
|
||
// clang-format off | ||
// The fixed ordering of the headers is required for CUTLASS 3.2+ | ||
#include <cute/tensor.hpp> | ||
#include <cutlass/gemm/collective/collective_builder.hpp> // @manual | ||
#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual | ||
#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual | ||
// clang-format on | ||
|
||
#include "cutlass_extensions/include/kernel_mode.h" | ||
|
||
namespace fbgemm_gpu { | ||
|
||
#if CUDART_VERSION >= 12000 | ||
|
||
template < | ||
int TB_M, | ||
int TB_N, | ||
int TB_K, | ||
int TBS_M, | ||
int TBS_N, | ||
int TBS_K, | ||
bool PONG, | ||
typename WEIGHT_SCALE_DTYPE> | ||
at::Tensor bf16i4bf16_rowwise_batched_impl( | ||
at::Tensor X, // BF16 | ||
at::Tensor WQ, // INT4 | ||
at::Tensor w_scale, | ||
at::Tensor w_zp) { | ||
// XQ: B x M x K | ||
// WQ: B x N x K | ||
// output: B x M x N | ||
int B = X.size(0); | ||
int M = X.size(1); | ||
int N = WQ.size(1); | ||
int K = X.size(2); | ||
|
||
int num_groups = w_scale.size(0) / B; | ||
|
||
TORCH_CHECK(X.is_cuda() && X.is_contiguous()); | ||
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); | ||
TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); | ||
TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous()); | ||
TORCH_CHECK(K >= num_groups && K % num_groups == 0); | ||
|
||
int group_size = K / num_groups; | ||
|
||
auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); | ||
|
||
using ElementInputA = cutlass::bfloat16_t; | ||
using LayoutInputA = cutlass::layout::ColumnMajor; | ||
constexpr int AlignmentInputA = | ||
128 / | ||
cutlass::sizeof_bits< | ||
ElementInputA>::value; // Memory access granularity/alignment of A | ||
// matrix in units of elements (up to 16 bytes) | ||
|
||
using ElementInputB = cutlass::int4b_t; | ||
using LayoutInputB = cutlass::layout::RowMajor; | ||
constexpr int AlignmentInputB = | ||
128 / | ||
cutlass::sizeof_bits< | ||
ElementInputB>::value; // Memory access granularity/alignment of B | ||
// matrix in units of elements (up to 16 bytes) | ||
|
||
using ElementScale = WEIGHT_SCALE_DTYPE; | ||
using ElementZeroPoint = WEIGHT_SCALE_DTYPE; | ||
using ElementComputeEpilogue = float; | ||
using ElementAccumulator = float; | ||
|
||
using ElementOutput = cutlass::bfloat16_t; | ||
using LayoutOutput = cutlass::layout::ColumnMajor; | ||
constexpr int AlignmentOutput = | ||
128 / | ||
cutlass::sizeof_bits< | ||
ElementOutput>::value; // Memory access granularity/alignment of C | ||
// matrix in units of elements (up to 16 bytes) | ||
|
||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that | ||
// supports the intended feature | ||
using OperatorClass = cutlass::arch::OpClassTensorOp; | ||
using TileShape = cute::Shape< | ||
cute::Int<TB_M>, | ||
cute::Int<TB_N>, | ||
cute::Int<TB_K>>; // Threadblock-level | ||
// tile size | ||
using ClusterShape = cute::Shape< | ||
cute::Int<TBS_M>, | ||
cute::Int<TBS_N>, | ||
cute::Int<TBS_K>>; // Shape of the | ||
// threadblocks in a | ||
// cluster | ||
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput; | ||
using PongSchedule = | ||
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput; | ||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; | ||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; | ||
using MainLoopSchedule = | ||
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>; | ||
|
||
using CollectiveEpilogue = | ||
typename cutlass::epilogue::collective::CollectiveBuilder< | ||
cutlass::arch::Sm90, | ||
cutlass::arch::OpClassTensorOp, | ||
TileShape, | ||
ClusterShape, | ||
EpilogueTileType, | ||
ElementAccumulator, | ||
ElementAccumulator, | ||
ElementOutput, | ||
LayoutOutput, | ||
AlignmentOutput, | ||
ElementOutput, | ||
LayoutOutput, | ||
AlignmentOutput, | ||
EpilogueSchedule>::CollectiveOp; | ||
|
||
using CollectiveMainloop = | ||
typename cutlass::gemm::collective::CollectiveBuilder< | ||
ArchTag, | ||
OperatorClass, | ||
cute::tuple<ElementInputB, ElementScale, ElementZeroPoint>, | ||
LayoutInputB, | ||
AlignmentInputB, | ||
ElementInputA, | ||
LayoutInputA, | ||
AlignmentInputA, | ||
ElementAccumulator, | ||
TileShape, | ||
ClusterShape, | ||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( | ||
sizeof(typename CollectiveEpilogue::SharedStorage))>, | ||
MainLoopSchedule>::CollectiveOp; | ||
|
||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal< | ||
cute::Shape<int, int, int, int>, | ||
CollectiveMainloop, | ||
CollectiveEpilogue>; | ||
|
||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; | ||
|
||
using StrideInputA = typename Gemm::GemmKernel::StrideA; | ||
using StrideInputB = typename Gemm::GemmKernel::StrideB; | ||
using StrideOutput = typename Gemm::GemmKernel::StrideC; | ||
using StrideS = typename CollectiveMainloop::StrideScale; | ||
|
||
StrideInputA stride_a = cutlass::make_cute_packed_stride( | ||
StrideInputA{}, cute::make_shape(M, K, B)); | ||
StrideInputB stride_b = cutlass::make_cute_packed_stride( | ||
StrideInputB{}, cute::make_shape(N, K, B)); | ||
StrideOutput stride_output = cutlass::make_cute_packed_stride( | ||
StrideOutput{}, cute::make_shape(N, M, B)); | ||
StrideS stride_S = cutlass::make_cute_packed_stride( | ||
StrideS{}, cute::make_shape(N, num_groups, B)); | ||
|
||
typename Gemm::Arguments arguments{ | ||
cutlass::gemm::GemmUniversalMode::kGemm, | ||
{N, M, K, B}, | ||
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()), | ||
stride_b, | ||
reinterpret_cast<ElementInputA*>(X.data_ptr()), | ||
stride_a, | ||
reinterpret_cast<ElementScale*>(w_scale.data_ptr()), | ||
stride_S, | ||
group_size, | ||
reinterpret_cast<ElementZeroPoint*>(w_zp.data_ptr())}, | ||
{{1.0, 0.0}, | ||
(ElementOutput*)Y.data_ptr<at::BFloat16>(), | ||
stride_output, | ||
(ElementOutput*)Y.data_ptr<at::BFloat16>(), | ||
stride_output}}; | ||
|
||
Gemm gemm; | ||
|
||
// Using the arguments, query for extra workspace required for matrix | ||
// multiplication computation | ||
size_t workspace_size = Gemm::get_workspace_size(arguments); | ||
|
||
// Allocate workspace memory | ||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); | ||
|
||
// Check the problem size is supported or not | ||
cutlass::Status status = gemm.can_implement(arguments); | ||
if (status != cutlass::Status::kSuccess) { | ||
throw std::runtime_error("cutlass cannot implement"); | ||
} | ||
|
||
// Initialize CUTLASS kernel with arguments and workspace pointer | ||
status = gemm.initialize(arguments, workspace.get()); | ||
if (status != cutlass::Status::kSuccess) { | ||
throw std::runtime_error("cutlass cannot initialize"); | ||
} | ||
|
||
status = gemm(at::cuda::getCurrentCUDAStream()); | ||
|
||
if (status != cutlass::Status::kSuccess) { | ||
throw std::runtime_error( | ||
std::string("cutlass cannot run") + | ||
cutlass::cutlassGetStatusString(status)); | ||
} | ||
C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
|
||
return Y; | ||
} | ||
|
||
template <typename WEIGHT_SCALE_DTYPE> | ||
at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel( | ||
at::Tensor X, // BF16 | ||
at::Tensor WQ, // INT4 | ||
at::Tensor w_scale, | ||
at::Tensor w_zp) { | ||
KernelMode kernel = get_batched_kernel_mode(X, WQ); | ||
if (kernel == KernelMode::Small) { | ||
return bf16i4bf16_rowwise_batched_impl< | ||
64, | ||
128, | ||
64, | ||
2, | ||
1, | ||
1, | ||
true, | ||
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); | ||
} else if (kernel == KernelMode::Large) { | ||
return bf16i4bf16_rowwise_batched_impl< | ||
128, | ||
128, | ||
64, | ||
2, | ||
1, | ||
1, | ||
true, | ||
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); | ||
} else { | ||
return bf16i4bf16_rowwise_batched_impl< | ||
128, | ||
128, | ||
64, | ||
2, | ||
1, | ||
1, | ||
false, | ||
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); | ||
} | ||
} | ||
|
||
/**
 * Entry point for the batched BF16 x INT4 row-wise weight-only GEMM.
 *
 * Validates the tensor dtypes and forwards to the dispatcher instantiated for
 * the scale / zero-point element type (float32, float16 or bfloat16).
 *
 * @param X        Activation tensor; must be bfloat16 because the kernel
 *                 reads its storage as cutlass::bfloat16_t.
 * @param WQ       Quantized INT4 weight tensor.
 * @param w_scale  Weight scales.
 * @param w_zp     Weight zero points; must have the same dtype as w_scale.
 * @return The GEMM result produced by the selected kernel.
 *
 * @throws c10::Error          if a dtype requirement is violated.
 * @throws std::runtime_error  if the scale dtype is unsupported (defensive;
 *                             not reachable once the dtype check has passed).
 */
at::Tensor bf16i4bf16_rowwise_batched(
    at::Tensor X, // BF16
    at::Tensor WQ, // INT4
    at::Tensor w_scale,
    at::Tensor w_zp) {
  // Check datatypes.
  // The implementation reinterprets X's buffer as bf16 without converting, so
  // any other dtype would be silently misread.
  TORCH_CHECK(
      X.dtype() == at::kBFloat16, "Activation tensor X must be bfloat16.");
  TORCH_CHECK(
      (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) ||
          (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) ||
          (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
      "Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same .");

  if (w_scale.dtype() == at::kFloat) {
    return dispatch_bf16i4bf16_rowwise_batched_kernel<float>(
        X, WQ, w_scale, w_zp);
  } else if (w_scale.dtype() == at::kHalf) {
    return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::half_t>(
        X, WQ, w_scale, w_zp);
  } else if (w_scale.dtype() == at::kBFloat16) {
    return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::bfloat16_t>(
        X, WQ, w_scale, w_zp);
  } else {
    throw std::runtime_error(
        "Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched");
  }
}
|
||
#else | ||
|
||
/**
 * Fallback compiled when CUDART_VERSION is below 12000: the kernel is not
 * built, so every call fails immediately.
 *
 * @throws std::runtime_error always.
 */
at::Tensor bf16i4bf16_rowwise_batched(
    at::Tensor X, // BF16
    at::Tensor WQ, // INT4
    at::Tensor w_scale,
    at::Tensor w_zp) {
  // requires CUDA>=12
  throw std::runtime_error("CUDA version is older than 12.0");
}
|
||
#endif | ||
|
||
} // namespace fbgemm_gpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.