[tritonbench] Fix colfax_cutlass flash_attention operator #2401

Closed · wants to merge 4 commits
4 changes: 2 additions & 2 deletions torchbenchmark/operators/flash_attention/operator.py
@@ -83,8 +83,8 @@
# colfax Flash Attention V2 for Hopper
torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
else:
- from userbenchmark.triton.utils import load_library
- load_library("colfax_cutlass/fmha_forward_lib.so")
+ from userbenchmark.triton.loader import load_library
+ load_library("cutlass_kernels/fmha_forward_lib.so")
colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
except (ImportError, IOError, AttributeError):
colfax_cutlass_fmha = None
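Note: the OSS path now resolves the prebuilt kernel library through `userbenchmark.triton.loader.load_library` and looks for it under `cutlass_kernels/` rather than `colfax_cutlass/`. The loader itself is not part of this diff; a minimal sketch of what such a helper could look like, assuming it resolves the .so relative to the `userbenchmark/triton` directory:

```python
# Hypothetical sketch only; the real userbenchmark.triton.loader is not shown in
# this PR, and the relative-path convention below is an assumption.
import os

import torch


def load_library(library_path: str) -> None:
    # Assume compiled kernel libraries live under userbenchmark/triton/,
    # next to this module, e.g. "cutlass_kernels/fmha_forward_lib.so".
    prefix = os.path.dirname(os.path.abspath(__file__))
    torch.ops.load_library(os.path.join(prefix, library_path))
```

Once the .so is loaded, the op registered by register_op.cu is available as `torch.ops.cutlass.fmha_forward`, which the operator binds to above.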
24 changes: 0 additions & 24 deletions userbenchmark/triton/cutlass_kernels/include/fmha_forward.h

This file was deleted.

7 changes: 3 additions & 4 deletions userbenchmark/triton/cutlass_kernels/install.py
@@ -6,7 +6,7 @@
CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"]
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.parent
FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass")
FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass")
COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass_kernels")

@@ -37,6 +37,7 @@
COMPILER_FLAGS = [
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
@@ -63,9 +64,7 @@
"-ldl",
]
FMHA_SOURCES = [
- # Source 1
- f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}",
- # Source 2
+ # Source
f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}",
"-o",
"fmha_forward_lib.so",
37 changes: 17 additions & 20 deletions userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
@@ -21,17 +21,17 @@

// #include "autogen/cutlassF.h"
#include "pytorch_utils.h"
#include "fmha_forward.h"
#include "fmha_forward.cu"

- template <typename PrecType, typename OutputType, int HEADDIM>
+ template <typename PrecType, int HEADDIM>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
fmha_forward(
const int64_t& seq_length,
const int64_t& key_length,
const int64_t& batch,
const at::Tensor& query, // [b, seqlen, num_heads, K]
const at::Tensor& key, // [b, seqlen, num_heads, K]
- const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+ at::Tensor& value, // [b, seqlen, num_heads, Kv]
const float& scale) {
TORCH_CHECK(query.dim() == 4);
TORCH_CHECK(key.dim() == 4);
@@ -70,7 +70,7 @@ fmha_forward(
query.options().dtype(CutlassToAtenDtype<PrecType>::atScalarType()));
at::Tensor ret = at::empty(
{B, M, num_heads, Kv},
- query.options().dtype(CutlassToAtenDtype<OutputType>::atScalarType()));
+ query.options().dtype(CutlassToAtenDtype<PrecType>::atScalarType()));
using AccumType = float; // AccumType is always float.

at::Tensor devMiOut = at::empty(
@@ -80,16 +80,16 @@
{B, M, num_heads},
query.options().dtype(CutlassToAtenDtype<AccumType>::atScalarType()));

- fmhaForwardDevice<PrecType, OutputType, AccumType, HEADDIM>(
+ fmhaForwardDevice<PrecType, AccumType, HEADDIM>(
seq_length,
key_length,
num_heads,
B,
reinterpret_cast<PrecType const*>(query.data_ptr()),
reinterpret_cast<PrecType const*>(key.data_ptr()),
- reinterpret_cast<OutputType const*>(value.data_ptr()),
- reinterpret_cast<OutputType*>(S.data_ptr()),
- reinterpret_cast<OutputType*>(ret.data_ptr()),
+ reinterpret_cast<PrecType*>(value.data_ptr()),
+ reinterpret_cast<PrecType*>(S.data_ptr()),
+ reinterpret_cast<PrecType*>(ret.data_ptr()),
reinterpret_cast<AccumType*>(devMiOut.data_ptr()),
reinterpret_cast<AccumType*>(devSprimeOut.data_ptr()),
1,
@@ -99,25 +99,25 @@
return std::make_tuple(S, ret, devMiOut, devSprimeOut);
}

- template<typename compute_data_type, typename output_data_type>
+ template<typename compute_data_type>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
launch_forward(
const int64_t& seq_length,
const int64_t& key_length,
const int64_t& batch,
const at::Tensor& query, // [b, seqlen, num_heads, K]
const at::Tensor& key, // [b, seqlen, num_heads, K]
- const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+ at::Tensor& value, // [b, seqlen, num_heads, Kv]
const double& scale,
const int64_t& Kdim) {
if (Kdim == 64) {
- return fmha_forward<compute_data_type, output_data_type, 64>(
+ return fmha_forward<compute_data_type, 64>(
seq_length, key_length, batch, query, key, value, scale);
} else if (Kdim == 128) {
- return fmha_forward<compute_data_type, output_data_type, 128>(
+ return fmha_forward<compute_data_type, 128>(
seq_length, key_length, batch, query, key, value, scale);
} else if (Kdim == 256) {
- return fmha_forward<compute_data_type, output_data_type, 256>(
+ return fmha_forward<compute_data_type, 256>(
seq_length, key_length, batch, query, key, value, scale);
}
throw std::runtime_error("Kdim wrong");
@@ -131,18 +131,15 @@ fmha_forward_dispatch(
const int64_t& batch,
const at::Tensor& query, // [b, seqlen, num_heads, K]
const at::Tensor& key, // [b, seqlen, num_heads, K]
- const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+ at::Tensor& value, // [b, seqlen, num_heads, Kv]
const double& scale) {
int64_t Kdim = query.size(-1);

if (query.scalar_type() == at::kHalf){
- return launch_forward<cutlass::half_t, cutlass::half_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
+ return launch_forward<cutlass::half_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
}
else if (query.scalar_type() == at::kBFloat16){
- return launch_forward<cutlass::bfloat16_t, cutlass::bfloat16_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
- }
- else if (query.scalar_type() == at::kFloat8_e4m3fn){
- return launch_forward<cutlass::float_e4m3_t, cutlass::bfloat16_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
+ return launch_forward<cutlass::bfloat16_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
}
else {
std::cout << "unsupported data type: " << query.scalar_type() << std::endl;
@@ -159,7 +156,7 @@ fmha_forward_dispatch_meta(
const int64_t& batch,
const at::Tensor& query, // [b, seqlen, num_heads, K]
const at::Tensor& key, // [b, seqlen, num_heads, K]
- const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+ at::Tensor& value, // [b, seqlen, num_heads, Kv]
const double& scale) {

TORCH_CHECK(query.dim() == 4);
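For reference, after this change the op carries a single element type end to end (the separate `OutputType` template parameter and the fp8 input branch are gone), so only fp16 and bf16 inputs are dispatched, at head dims 64, 128, or 256. A hedged usage sketch from Python, assuming the shared library has been loaded as in operator.py and that the kernel accepts contiguous [B, seqlen, num_heads, K] tensors:

```python
# Usage sketch based on the fmha_forward_dispatch signature in this diff; shapes,
# strides, and any layout requirements beyond [B, seqlen, num_heads, K] are assumptions.
import torch

B, seqlen, num_heads, head_dim = 4, 512, 8, 128  # head_dim must be 64, 128, or 256
q = torch.randn(B, seqlen, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
scale = head_dim ** -0.5

# Returns (S, ret, devMiOut, devSprimeOut); fp8 inputs are no longer accepted here.
S, out, mi, sprime = torch.ops.cutlass.fmha_forward(seqlen, seqlen, B, q, k, v, scale)
```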