pytest mostly works except for sgmv
abcdabcd987 committed Apr 25, 2024
1 parent 8a75f02 commit a9b4fe7
Showing 5 changed files with 63 additions and 52 deletions.
51 changes: 23 additions & 28 deletions csrc/flashinfer_adapter/flashinfer_all.cu
@@ -47,16 +47,15 @@ inline T* alloc_from_buf(void** buf, int n) {
template <typename T>
bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs,
int32_t* kv_indptr, int32_t* last_page_offset,
- void* tmpbuf, int head_dim, int num_layers,
- int layer_idx, int group_size,
+ void* tmpbuf, int head_dim, int group_size,
int num_kv_heads, int page_size,
int batch_size) {
return DISPATCH_page_size(page_size, [&] {
return DISPATCH_group_size(group_size, [&] {
return DISPATCH_head_dim(head_dim, [&] {
using namespace flashinfer;
BatchPrefillHandler handler;
- paged_kv_t<PageStorage::kPointer, QKVLayout::kNHD, T, int32_t> paged_kv(
+ paged_kv_t<PageStorage::kPointer, QKVLayout::kHND, T, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_ptrs, kv_indptr,
last_page_offset);
int num_qo_heads = num_kv_heads * group_size;
@@ -78,7 +77,7 @@ bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs,
return false;
}
status = BatchPrefillWithPagedKVCacheWrapperDispatched<
- PageStorage::kPointer, QKVLayout::kNHD, PAGE_SIZE, GROUP_SIZE,
+ PageStorage::kPointer, QKVLayout::kHND, PAGE_SIZE, GROUP_SIZE,
HEAD_DIM, PosEncodingMode::kRoPELlama, allow_fp16_qk_reduction,
causal, T, T, int32_t>(&handler, q, qo_indptr, q_offset, paged_kv,
o, lse, sm_scale, rope_scale, rope_theta,
@@ -99,14 +98,13 @@ bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs,
template <typename T>
bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr,
int32_t* last_page_offset, void* tmpbuf,
- int head_dim, int num_layers, int layer_idx,
- int group_size, int num_kv_heads,
+ int head_dim, int group_size, int num_kv_heads,
int page_size, int batch_size) {
return DISPATCH_group_size(group_size, [&] {
return DISPATCH_head_dim(head_dim, [&] {
using namespace flashinfer;
BatchDecodeHandler handler;
- paged_kv_t<PageStorage::kPointer, QKVLayout::kNHD, T, int32_t> paged_kv(
+ paged_kv_t<PageStorage::kPointer, QKVLayout::kHND, T, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_ptrs, kv_indptr,
last_page_offset);
int num_qo_heads = num_kv_heads * group_size;
@@ -117,18 +115,18 @@ bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr,
float rope_theta = 1e4;
cudaStream_t stream = nullptr;
size_t workspace_size_in_bytes = 32 * 1024 * 1024;
- auto status = handler.BeginForward<PageStorage::kPointer, QKVLayout::kNHD,
-     T, T, int32_t>(
+ auto status = handler.BeginForwardDispatched<
+     GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, QKVLayout::kHND,
+     PosEncodingMode::kRoPELlama, T, T, int32_t>(
tmpbuf, workspace_size_in_bytes, kv_indptr, last_page_offset,
- batch_size, num_qo_heads, num_kv_heads, head_dim, page_size,
- PosEncodingMode::kRoPELlama);
+ batch_size, num_qo_heads, page_size);
if (status != cudaSuccess) {
fprintf(stderr, "batch_decode failed in handler.BeginForward: %s\n",
cudaGetErrorString(status));
return false;
}
status = BatchDecodeWithPagedKVCacheWrapperDispatched<
- PageStorage::kPointer, QKVLayout::kNHD, GROUP_SIZE, HEAD_DIM,
+ PageStorage::kPointer, QKVLayout::kHND, GROUP_SIZE, HEAD_DIM,
PosEncodingMode::kRoPELlama, T, T, int32_t>(
&handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale,
rope_theta, stream);
@@ -148,10 +146,9 @@ template <typename T>
bool FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr,
int32_t* last_page_offset, T* key, T* value,
int32_t* seqlen_indptr, int head_dim,
- int num_layers, int layer_idx, int num_kv_heads,
- int page_size, int batch_size) {
+ int num_kv_heads, int page_size, int batch_size) {
using namespace flashinfer;
- paged_kv_t<PageStorage::kPointer, QKVLayout::kNHD, T, int32_t> paged_kv(
+ paged_kv_t<PageStorage::kPointer, QKVLayout::kHND, T, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_ptrs, kv_indptr,
last_page_offset);
cudaStream_t stream = nullptr;
@@ -167,10 +164,10 @@ bool FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr,
template <typename T>
bool FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
int32_t* last_page_offset, T* key, T* value,
- int head_dim, int num_layers, int layer_idx,
- int num_kv_heads, int page_size, int batch_size) {
+ int head_dim, int num_kv_heads, int page_size,
+ int batch_size) {
using namespace flashinfer;
- paged_kv_t<PageStorage::kPointer, QKVLayout::kNHD, T, int32_t> paged_kv(
+ paged_kv_t<PageStorage::kPointer, QKVLayout::kHND, T, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_ptrs, kv_indptr,
last_page_offset);
cudaStream_t stream = nullptr;
@@ -186,33 +183,31 @@ bool FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
#define INST_FlashInferBatchPrefillKernel(T) \
template bool FlashInferBatchPrefillKernel<T>( \
T * o, T * q, int32_t * qo_indptr, T * *kv_ptrs, int32_t * kv_indptr, \
- int32_t * last_page_offset, void* tmpbuf, int head_dim, int num_layers, \
- int layer_idx, int group_size, int num_kv_heads, int page_size, \
- int batch_size);
+ int32_t * last_page_offset, void* tmpbuf, int head_dim, int group_size, \
+ int num_kv_heads, int page_size, int batch_size);
INST_FlashInferBatchPrefillKernel(nv_half);
INST_FlashInferBatchPrefillKernel(nv_bfloat16);

#define INST_FlashInferBatchDecodeKernel(T) \
template bool FlashInferBatchDecodeKernel<T>( \
T * o, T * q, T * *kv_ptrs, int32_t * kv_indptr, \
- int32_t * last_page_offset, void* tmpbuf, int head_dim, int num_layers, \
- int layer_idx, int group_size, int num_kv_heads, int page_size, \
- int batch_size);
+ int32_t * last_page_offset, void* tmpbuf, int head_dim, int group_size, \
+ int num_kv_heads, int page_size, int batch_size);
INST_FlashInferBatchDecodeKernel(nv_half);
INST_FlashInferBatchDecodeKernel(nv_bfloat16);

#define INST_FlashInferInitKvKernel(T) \
template bool FlashInferInitKvKernel<T>( \
T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \
- T * value, int32_t * seqlen_indptr, int head_dim, int num_layers, \
- int layer_idx, int num_kv_heads, int page_size, int batch_size);
+ T * value, int32_t * seqlen_indptr, int head_dim, int num_kv_heads, \
+ int page_size, int batch_size);
INST_FlashInferInitKvKernel(nv_half);
INST_FlashInferInitKvKernel(nv_bfloat16);

#define INST_FlashInferAppendKvKernel(T) \
template bool FlashInferAppendKvKernel<T>( \
T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \
- T * value, int head_dim, int num_layers, int layer_idx, \
- int num_kv_heads, int page_size, int batch_size);
+ T * value, int head_dim, int num_kv_heads, int page_size, \
+ int batch_size);
INST_FlashInferAppendKvKernel(nv_half);
INST_FlashInferAppendKvKernel(nv_bfloat16);
13 changes: 5 additions & 8 deletions csrc/flashinfer_adapter/flashinfer_config.h
@@ -4,27 +4,24 @@
template <typename T>
bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs,
int32_t* kv_indptr, int32_t* last_page_offset,
- void* tmpbuf, int head_dim, int num_layers,
- int layer_idx, int group_size,
+ void* tmpbuf, int head_dim, int group_size,
int num_kv_heads, int page_size,
int batch_size);

template <typename T>
bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr,
int32_t* last_page_offset, void* tmpbuf,
- int head_dim, int num_layers, int layer_idx,
- int group_size, int num_kv_heads,
+ int head_dim, int group_size, int num_kv_heads,
int page_size, int batch_size);

template <typename T>
bool FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr,
int32_t* last_page_offset, T* key, T* value,
int32_t* seqlen_indptr, int head_dim,
- int num_layers, int layer_idx, int num_kv_heads,
- int page_size, int batch_size);
+ int num_kv_heads, int page_size, int batch_size);

template <typename T>
bool FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
int32_t* last_page_offset, T* key, T* value,
- int head_dim, int num_layers, int layer_idx,
- int num_kv_heads, int page_size, int batch_size);
+ int head_dim, int num_kv_heads, int page_size,
+ int batch_size);
8 changes: 4 additions & 4 deletions csrc/flashinfer_adapter/flashinfer_inst.h
@@ -6,13 +6,13 @@
#define INST_BatchPrefill_X(T, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, NUM_FRAGS_X) \
namespace flashinfer { \
template cudaError_t BatchPrefillWithPagedKVCacheDispatched< \
- PageStorage::kPointer, QKVLayout::kNHD, NUM_FRAGS_X, PAGE_SIZE, \
+ PageStorage::kPointer, QKVLayout::kHND, NUM_FRAGS_X, PAGE_SIZE, \
GROUP_SIZE, HEAD_DIM, PosEncodingMode::kRoPELlama, \
/* ALLOW_FP16_QK_REDUCTION= */ false, /* CAUSAL= */ true, T, T, \
int32_t>( \
T * q, int32_t* request_indices, int32_t* tile_indices, \
int32_t* qo_indptr, int32_t* q_offset, \
- paged_kv_t<PageStorage::kPointer, QKVLayout::kNHD, T, int32_t> paged_kv, \
+ paged_kv_t<PageStorage::kPointer, QKVLayout::kHND, T, int32_t> paged_kv, \
T* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, \
float rope_scale, float rope_theta, cudaStream_t stream); \
}
@@ -24,10 +24,10 @@
#define INST_BatchDecode(T, PAGE_SIZE, GROUP_SIZE, HEAD_DIM) \
namespace flashinfer { \
template cudaError_t BatchDecodeWithPagedKVCacheDispatched< \
- GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, QKVLayout::kNHD, \
+ GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, QKVLayout::kHND, \
PosEncodingMode::kRoPELlama, T, T, int32_t>( \
T * q, int32_t* q_offset, \
- paged_kv_t<PageStorage::kPointer, QKVLayout::kNHD, T, int32_t> paged_kv, \
+ paged_kv_t<PageStorage::kPointer, QKVLayout::kHND, T, int32_t> paged_kv, \
kv_partition_info_t<int32_t> kv_partition_info, T* o, T* tmp, \
float* lse, float sm_scale, float rope_scale, float rope_theta, \
cudaStream_t stream); \
41 changes: 30 additions & 11 deletions csrc/punica_ops.cc
Expand Up @@ -49,6 +49,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
#define CHECK_GE(a, b) \
TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

+ torch::Tensor GetLayerKvPtrs(torch::Tensor kv_ptrs, int num_layers,
+                              int layer_idx, int num_kv_heads, int page_size,
+                              int head_dim, int kv_elem_size) {
+   int layer_stride = 2 * num_kv_heads * page_size * head_dim;
+   return kv_ptrs + layer_idx * layer_stride * kv_elem_size;
+ }
+
//====== dispatch pytorch dtype ======

#define _DISPATCH_SWITCH(cond, ...) \
@@ -106,13 +113,16 @@ void batch_prefill(torch::Tensor o, torch::Tensor q, torch::Tensor qo_indptr,
CHECK_GE(tmpbuf.nbytes(), 64 << 20);

bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] {
+ auto layer_kv_ptrs =
+     GetLayerKvPtrs(kv_ptrs, num_layers, layer_idx, num_kv_heads, page_size,
+                    head_dim, sizeof(c_type));
return FlashInferBatchPrefillKernel(
static_cast<c_type*>(o.data_ptr()), static_cast<c_type*>(q.data_ptr()),
qo_indptr.data_ptr<int32_t>(),
- reinterpret_cast<c_type**>(kv_ptrs.data_ptr<int64_t>()),
+ reinterpret_cast<c_type**>(layer_kv_ptrs.data_ptr<int64_t>()),
kv_indptr.data_ptr<int32_t>(), last_page_offset.data_ptr<int32_t>(),
- tmpbuf.data_ptr(), head_dim, num_layers, layer_idx, group_size,
- num_kv_heads, page_size, batch_size);
+ tmpbuf.data_ptr(), head_dim, group_size, num_kv_heads, page_size,
+ batch_size);
});
TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(),
" page_size=", page_size, " group_size=", group_size,
@@ -147,12 +157,15 @@ void batch_decode(torch::Tensor o, torch::Tensor q, torch::Tensor kv_ptrs,
CHECK_GE(tmpbuf.nbytes(), 64 << 20);

bool ok = DISPATCH_TORCH_DTYPE(q.scalar_type(), [&] {
+ auto layer_kv_ptrs =
+     GetLayerKvPtrs(kv_ptrs, num_layers, layer_idx, num_kv_heads, page_size,
+                    head_dim, sizeof(c_type));
return FlashInferBatchDecodeKernel(
static_cast<c_type*>(o.data_ptr()), static_cast<c_type*>(q.data_ptr()),
- reinterpret_cast<c_type**>(kv_ptrs.data_ptr<int64_t>()),
+ reinterpret_cast<c_type**>(layer_kv_ptrs.data_ptr<int64_t>()),
kv_indptr.data_ptr<int32_t>(), last_page_offset.data_ptr<int32_t>(),
- tmpbuf.data_ptr(), head_dim, num_layers, layer_idx, group_size,
- num_kv_heads, page_size, batch_size);
+ tmpbuf.data_ptr(), head_dim, group_size, num_kv_heads, page_size,
+ batch_size);
});
TORCH_CHECK(ok, "No suitable kernel.", " dtype=", q.scalar_type(),
" page_size=", page_size, " group_size=", group_size,
@@ -184,12 +197,15 @@ void init_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
CHECK_SHAPE(k, v);

bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] {
+ auto layer_kv_ptrs =
+     GetLayerKvPtrs(kv_ptrs, num_layers, layer_idx, num_kv_heads, page_size,
+                    head_dim, sizeof(c_type));
return FlashInferInitKvKernel<c_type>(
- reinterpret_cast<c_type**>(kv_ptrs.data_ptr<int64_t>()),
+ reinterpret_cast<c_type**>(layer_kv_ptrs.data_ptr<int64_t>()),
kv_indptr.data_ptr<int32_t>(), last_page_offset.data_ptr<int32_t>(),
static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(v.data_ptr()),
- seqlen_indptr.data_ptr<int32_t>(), head_dim, num_layers, layer_idx,
- num_kv_heads, page_size, batch_size);
+ seqlen_indptr.data_ptr<int32_t>(), head_dim, num_kv_heads, page_size,
+ batch_size);
});
TORCH_CHECK(ok, "Error in init_kv.", " dtype=", k.scalar_type(),
" head_dim=", head_dim);
@@ -218,11 +234,14 @@ void append_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
CHECK_SHAPE(k, v);

bool ok = DISPATCH_TORCH_DTYPE(k.scalar_type(), [&] {
+ auto layer_kv_ptrs =
+     GetLayerKvPtrs(kv_ptrs, num_layers, layer_idx, num_kv_heads, page_size,
+                    head_dim, sizeof(c_type));
return FlashInferAppendKvKernel<c_type>(
- reinterpret_cast<c_type**>(kv_ptrs.data_ptr<int64_t>()),
+ reinterpret_cast<c_type**>(layer_kv_ptrs.data_ptr<int64_t>()),
kv_indptr.data_ptr<int32_t>(), last_page_offset.data_ptr<int32_t>(),
static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(v.data_ptr()),
- head_dim, num_layers, layer_idx, num_kv_heads, page_size, batch_size);
+ head_dim, num_kv_heads, page_size, batch_size);
});
TORCH_CHECK(ok, "Error in append_kv.", " dtype=", k.scalar_type(),
" head_dim=", head_dim);
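For context on the csrc/punica_ops.cc change above: each entry of `kv_ptrs` holds the base address of a KV-cache page that stores every layer, so selecting one layer amounts to shifting every page pointer by a fixed byte offset, which is what the new `GetLayerKvPtrs` helper does on the whole int64 tensor at once. The sketch below reproduces that arithmetic in plain C++ under the assumption that a page is laid out as `[num_layers][2 (K/V)][num_kv_heads][page_size][head_dim]` elements (the layout implied by `layer_stride`); the `layer_byte_offset` function and the literal addresses are invented for illustration, not part of the commit.

```cpp
// Illustrative sketch only (not part of the commit): the per-layer pointer
// arithmetic that GetLayerKvPtrs applies to every entry of kv_ptrs, assuming
// each page stores all layers contiguously with
// 2 * num_kv_heads * page_size * head_dim elements per layer (K and V blocks).
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Byte offset from a page's base pointer to the start of layer `layer_idx`.
int64_t layer_byte_offset(int layer_idx, int num_kv_heads, int page_size,
                          int head_dim, int elem_size) {
  int64_t layer_stride = 2LL * num_kv_heads * page_size * head_dim;  // elements
  return layer_idx * layer_stride * elem_size;                       // bytes
}

int main() {
  // Hypothetical shapes; fp16 elements are 2 bytes wide.
  int num_layers = 4, num_kv_heads = 8, page_size = 16, head_dim = 128;
  int elem_size = 2;
  int layer_idx = 2;
  assert(layer_idx < num_layers);

  // Pretend these are device addresses of three page bases (one int64 each,
  // mirroring kv_ptrs.data_ptr<int64_t>()).
  std::vector<int64_t> kv_ptrs = {0x10000000, 0x20000000, 0x30000000};

  int64_t off = layer_byte_offset(layer_idx, num_kv_heads, page_size,
                                  head_dim, elem_size);
  for (int64_t& p : kv_ptrs) p += off;  // the shift applied to each entry

  std::printf("byte offset into each page for layer %d: %lld\n", layer_idx,
              static_cast<long long>(off));
  std::printf("first page pointer for layer %d: 0x%llx\n", layer_idx,
              static_cast<unsigned long long>(kv_ptrs[0]));
  return 0;
}
```

Shifting the pointer tensor once per call is what lets the diff drop `num_layers` and `layer_idx` from all four FlashInfer wrapper signatures.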
