2024-10-05 nightly release (7a4472a)

pytorchbot committed Oct 5, 2024
1 parent d1fa47f commit 3602496
Showing 10 changed files with 581 additions and 37 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -49,6 +49,7 @@ else()
src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu
src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu
src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu
src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
src/quantize/quantize.cu
src/quantize/quantize.cpp)
endif()
298 changes: 298 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
@@ -0,0 +1,298 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

// clang-format off
// The fixed ordering of the headers is required for CUTLASS 3.2+
#include <cute/tensor.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
// clang-format on

#include "cutlass_extensions/include/kernel_mode.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

template <
int TB_M,
int TB_N,
int TB_K,
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
typename WEIGHT_SCALE_DTYPE>
at::Tensor bf16i4bf16_rowwise_batched_impl(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
// X: B x M x K
// WQ: B x N x K
// output: B x M x N
int B = X.size(0);
int M = X.size(1);
int N = WQ.size(1);
int K = X.size(2);

int num_groups = w_scale.size(0) / B;

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous());
TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous());
TORCH_CHECK(K >= num_groups && K % num_groups == 0);

int group_size = K / num_groups;
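// Row-wise, group-wise quantization: each group of `group_size` consecutive
// elements along K shares one scale and one zero point per weight row.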

auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::bfloat16_t;
using LayoutInputA = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::int4b_t;
using LayoutInputB = cutlass::layout::RowMajor;
constexpr int AlignmentInputB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementScale = WEIGHT_SCALE_DTYPE;
using ElementZeroPoint = WEIGHT_SCALE_DTYPE;
using ElementComputeEpilogue = float;
using ElementAccumulator = float;

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
ElementOutput>::value; // Memory access granularity/alignment of C
// matrix in units of elements (up to 16 bytes)

using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
// supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
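// The PONG template flag selects the ping-pong TMA warp-specialized mainloop
// schedule; otherwise the default warp-specialized mixed-input schedule is used.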
using MainLoopSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementAccumulator,
ElementOutput,
LayoutOutput,
AlignmentOutput,
ElementOutput,
LayoutOutput,
AlignmentOutput,
EpilogueSchedule>::CollectiveOp;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
cute::tuple<ElementInputB, ElementScale, ElementZeroPoint>,
LayoutInputB,
AlignmentInputB,
ElementInputA,
LayoutInputA,
AlignmentInputA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainLoopSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideS = typename CollectiveMainloop::StrideScale;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, B));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, B));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, B));
StrideS stride_S = cutlass::make_cute_packed_stride(
StrideS{}, cute::make_shape(N, num_groups, B));

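// The INT4 weight is passed to CUTLASS as operand A and the BF16 activation
// as operand B, so the problem shape is given as {N, M, K, B}; the
// column-major N x M output per batch is the row-major M x N result.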
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, B},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(X.data_ptr()),
stride_a,
reinterpret_cast<ElementScale*>(w_scale.data_ptr()),
stride_S,
group_size,
reinterpret_cast<ElementZeroPoint*>(w_zp.data_ptr())},
{{1.0, 0.0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
stride_output}};

Gemm gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());

if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run: ") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

template <typename WEIGHT_SCALE_DTYPE>
at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
KernelMode kernel = get_batched_kernel_mode(X, WQ);
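// Template arguments are: TB_M, TB_N, TB_K (threadblock tile), TBS_M, TBS_N,
// TBS_K (cluster shape), PONG (use the ping-pong schedule), and the weight
// scale/zero-point dtype.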
if (kernel == KernelMode::Small) {
return bf16i4bf16_rowwise_batched_impl<
64,
128,
64,
2,
1,
1,
true,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
} else if (kernel == KernelMode::Large) {
return bf16i4bf16_rowwise_batched_impl<
128,
128,
64,
2,
1,
1,
true,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
} else {
return bf16i4bf16_rowwise_batched_impl<
128,
128,
64,
2,
1,
1,
false,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
}
}

at::Tensor bf16i4bf16_rowwise_batched(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
// Check datatypes.
TORCH_CHECK(
(w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) ||
(w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) ||
(w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
"Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same .");

if (w_scale.dtype() == at::kFloat) {
return dispatch_bf16i4bf16_rowwise_batched_kernel<float>(
X, WQ, w_scale, w_zp);
} else if (w_scale.dtype() == at::kHalf) {
return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::half_t>(
X, WQ, w_scale, w_zp);
} else if (w_scale.dtype() == at::kBFloat16) {
return dispatch_bf16i4bf16_rowwise_batched_kernel<cutlass::bfloat16_t>(
X, WQ, w_scale, w_zp);
} else {
throw std::runtime_error(
"Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched");
}
}

#else

at::Tensor bf16i4bf16_rowwise_batched(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
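For orientation, a minimal caller-side sketch follows (not part of the commit). The INT4 packing layout (two values per int8 byte, giving a B x N x K/2 buffer) and the (B * num_groups) x N shape of the scale and zero-point tensors are assumptions inferred from the shape checks and strides above, and the sketch presumes the bf16i4bf16_rowwise_batched declaration is visible to the caller.

#include <ATen/ATen.h>

// Hypothetical helper, for illustration only.
at::Tensor run_bf16i4_batched_example(
    int64_t B, int64_t M, int64_t N, int64_t K, int64_t group_size) {
  auto cuda = at::TensorOptions().device(at::kCUDA);

  // Activations: B x M x K in bfloat16, contiguous on the GPU.
  at::Tensor X = at::randn({B, M, K}, cuda.dtype(at::kBFloat16));

  // Quantized weights hold B x N x K INT4 values; here they are assumed to be
  // packed two-per-byte into an int8 tensor of shape B x N x K/2 (the kernel
  // only reinterprets the raw buffer as cutlass::int4b_t).
  at::Tensor WQ = at::zeros({B, N, K / 2}, cuda.dtype(at::kChar));

  // Row-wise, group-wise scales and zero points. The kernel derives
  // num_groups = w_scale.size(0) / B and group_size = K / num_groups, so the
  // leading dimension must be B * (K / group_size); the trailing dimension is
  // assumed to be N.
  const int64_t num_groups = K / group_size;
  at::Tensor w_scale = at::ones({B * num_groups, N}, cuda.dtype(at::kBFloat16));
  at::Tensor w_zp = at::zeros({B * num_groups, N}, cuda.dtype(at::kBFloat16));

  // Returns a B x M x N bfloat16 tensor.
  return fbgemm_gpu::bf16i4bf16_rowwise_batched(X, WQ, w_scale, w_zp);
}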
@@ -335,6 +335,19 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
} else if (kernel == KernelMode::Medium) {
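// New Medium mode: a dedicated tile configuration for mid-sized batched
// shapes (see the updated get_batched_kernel_mode heuristic below).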
return f8f8bf16_rowwise_batched_impl<
64,
128,
128,
1,
2,
1,
true,
FastAccum,
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_batched_impl<
128,
cutlass_extensions/include/kernel_mode.h
@@ -12,7 +12,7 @@

namespace fbgemm_gpu {

- enum class KernelMode { Small, Large, Default };
+ enum class KernelMode { Small, Medium, Large, Default };

inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto M = XQ.size(0);
@@ -37,14 +37,14 @@ inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto K = XQ.size(2);
auto N = WQ.size(1);
auto BM = B * M;
- auto BN = B * N;
- auto BK = B * K;
- // Use a large kernel if at least two shapes are large....
- bool use_large_kernel =
- ((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BK >= 2048) ||
- (BK >= 2048 && BN >= 2048));
- if (BM <= 128 || BN <= 128) {
+ // Heuristic to determine kernel mode
+ bool use_medium_kernel =
+ ((BM <= 512 && ((N <= 8192 && K < 8192) || (N < 8192 && K <= 8192))));
+ bool use_large_kernel = ((BM > 512 && (N >= 1024 || K >= 1024)));
+ if (BM <= 128 || N <= 128) {
return KernelMode::Small;
+ } else if (use_medium_kernel) {
+ return KernelMode::Medium;
} else if (use_large_kernel) {
return KernelMode::Large;
} else {
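To make the new dispatch behaviour concrete, here is a standalone sketch that restates the updated batched heuristic on plain integers, plus a few worked shapes. It is for illustration only; the authoritative logic is get_batched_kernel_mode above (XQ is B x M x K, WQ is B x N x K).

#include <cstdint>

enum class Mode { Small, Medium, Large, Default };

// Mirrors the updated get_batched_kernel_mode heuristic.
Mode batched_mode(int64_t B, int64_t M, int64_t N, int64_t K) {
  const int64_t BM = B * M;
  if (BM <= 128 || N <= 128) {
    return Mode::Small;
  }
  const bool use_medium =
      BM <= 512 && ((N <= 8192 && K < 8192) || (N < 8192 && K <= 8192));
  if (use_medium) {
    return Mode::Medium;
  }
  const bool use_large = BM > 512 && (N >= 1024 || K >= 1024);
  if (use_large) {
    return Mode::Large;
  }
  return Mode::Default;
}

// Worked examples:
//   batched_mode(4, 64, 4096, 4096)  -> Medium (B*M = 256, N and K below 8192)
//   batched_mode(8, 256, 8192, 8192) -> Large  (B*M = 2048, N >= 1024)
//   batched_mode(1, 64, 128, 4096)   -> Small  (B*M = 64)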
