Avoid preallocating buffer for kvcache
use `flashinfer::PageStorage::kPointer`
abcdabcd987 committed Nov 23, 2023
1 parent f2fc15c commit a737045
Showing 22 changed files with 608 additions and 363 deletions.
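
The switch from flashinfer's `PageStorage::kIndices` to `PageStorage::kPointer` changes how KV-cache pages are addressed: instead of carving every page out of one big preallocated `kv_data` buffer through a page-index array, each page is reached through its own device pointer, so pages can be allocated one at a time as sequences grow and no worst-case `capacity` has to be reserved up front. A minimal sketch of the two addressing schemes (the struct and field names below are illustrative, not flashinfer's actual definitions):

#include <cstdint>

// Old scheme (PageStorage::kIndices, sketch): every page lives inside one
// preallocated buffer and is located by an index.
template <typename T>
struct PagesByIndex {
  T* kv_data;           // one big buffer sized for `capacity` pages
  int32_t* kv_indices;  // page i -> slot in kv_data
  // page base = kv_data + kv_indices[i] * page_stride
};

// New scheme (PageStorage::kPointer, sketch): each page is its own
// allocation, located directly by pointer.
template <typename T>
struct PagesByPointer {
  T** kv_ptrs;  // device array of per-page base pointers
  // page base = kv_ptrs[i]; no up-front capacity needed
};

This is why every benchmark below drops its `capacity=` argument and renames `block_len` to `page_len`.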
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ wandb/
 model/
 *.env
 .coverage
+.vscode/
11 changes: 5 additions & 6 deletions benchmarks/bench_batch_decode.py
@@ -19,7 +19,7 @@ def __init__(
         self,
         num_heads: int,
         head_dim: int,
-        block_len: int,
+        page_len: int,
         seqlens: list[int],
         dtype: str,
         device: torch.device,
@@ -29,8 +29,7 @@ def __init__(
             num_layers=1,
             num_heads=num_heads,
             head_dim=head_dim,
-            capacity=sum((l + block_len - 1) // block_len for l in seqlens),
-            block_len=block_len,
+            page_len=page_len,
             dtype=dtype,
             device=device,
         )
@@ -76,15 +75,15 @@ def bench_batch_decode(f):
     seqlen_ = list(reversed(range(2048, 0, -64)))
     dtype = "float16"
     device = torch.device("cuda:0")
-    block_len = 16
+    page_len = 16
     head_dim = 128

     all_ = list(itertools.product(num_heads_, seqlen_, batch_size_))
     for num_heads, seqlen, batch_size in (pbar := tqdm(all_)):
         setup = dict(
             num_heads=num_heads,
             head_dim=head_dim,
-            block_len=block_len,
+            page_len=page_len,
             seqlen=seqlen,
             batch_size=batch_size,
         )
@@ -94,7 +93,7 @@ def bench_batch_decode(f):
         res = batch_decode_Resources(
             num_heads=num_heads,
             head_dim=head_dim,
-            block_len=block_len,
+            page_len=page_len,
             seqlens=[seqlen] * batch_size,
             dtype=dtype,
             device=device,
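
The deleted `capacity=sum((l + block_len - 1) // block_len for l in seqlens)` argument was the old worst-case reservation: ceil-divide each sequence length by the page length and preallocate that many pages. With pointer-based pages the pool can instead grow one page at a time. A hedged sketch of on-demand growth (the per-page layout and helper names are assumptions, not this repo's KvPool):

#include <cuda_runtime.h>
#include <vector>

// Allocate a single KV page holding K and V planes of
// [page_len, num_kv_heads, head_dim] each (layout assumed for illustration).
template <typename T>
T* AllocPage(size_t page_len, size_t num_kv_heads, size_t head_dim) {
  T* page = nullptr;
  cudaMalloc(&page, 2 * page_len * num_kv_heads * head_dim * sizeof(T));
  return page;
}

// Grow only when the sequence spills past its last page, instead of
// reserving ceil(seqlen / page_len) pages for the maximum length up front.
template <typename T>
void MaybeAppendPage(std::vector<T*>& pages, size_t seqlen, size_t page_len,
                     size_t num_kv_heads, size_t head_dim) {
  if (seqlen > pages.size() * page_len)
    pages.push_back(AllocPage<T>(page_len, num_kv_heads, head_dim));
}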
11 changes: 5 additions & 6 deletions benchmarks/bench_layer_lora_decode.py
@@ -26,7 +26,7 @@ class layer_lora_decode_Resources:
     def __init__(
         self,
         config: LlamaConfig,
-        block_len: int,
+        page_len: int,
         lora_rank: int,
         lora_popularity: int,
         seqlens: list[int],
@@ -39,8 +39,7 @@ def __init__(
             num_layers=1,
             num_heads=num_heads,
             head_dim=head_dim,
-            capacity=sum((l + block_len - 1) // block_len for l in seqlens),
-            block_len=block_len,
+            page_len=page_len,
             dtype=dtype,
             device=device,
         )
@@ -77,7 +76,7 @@ def bench_layer_lora_decode(f):
     seqlen_ = list(reversed(range(2048, 0, -64)))
     dtype = torch.float16
     device = torch.device("cuda:0")
-    block_len = 16
+    page_len = 16
     lora_rank = 16
     head_dim = 128

@@ -105,7 +104,7 @@ def bench_layer_lora_decode(f):
         gc_torch()
         res = layer_lora_decode_Resources(
             config=config,
-            block_len=block_len,
+            page_len=page_len,
             lora_rank=lora_rank,
             lora_popularity=pop,
             seqlens=[seqlen] * batch_size,
@@ -116,7 +115,7 @@ def bench_layer_lora_decode(f):
             num_heads=num_heads,
             head_dim=head_dim,
             intermediate_size=intermediate_size,
-            block_len=block_len,
+            page_len=page_len,
             lora_rank=lora_rank,
             lora_popularity=pop,
             num_lora_models=res.num_lora_models,
11 changes: 5 additions & 6 deletions benchmarks/bench_model_lora_decode.py
@@ -28,7 +28,7 @@ class model_lora_decode_Resources:
     def __init__(
         self,
         config: LlamaConfig,
-        block_len: int,
+        page_len: int,
         lora_rank: int,
         seqlens: list[int],
         dtype: torch.dtype,
@@ -40,8 +40,7 @@ def __init__(
             num_layers=config.num_hidden_layers,
             num_heads=num_heads,
             head_dim=head_dim,
-            capacity=sum((l + block_len - 1) // block_len for l in seqlens),
-            block_len=block_len,
+            page_len=page_len,
             dtype=dtype,
             device=device,
         )
@@ -80,7 +79,7 @@ def bench_model_lora_decode(f):
     seqlen_ = list(reversed(range(2048, 0, -64)))
     dtype = torch.float16
     device = torch.device("cuda:0")
-    block_len = 16
+    page_len = 16
     lora_rank = 16
     head_dim = 128

@@ -113,7 +112,7 @@ def bench_model_lora_decode(f):
         gc_torch()
         res = model_lora_decode_Resources(
             config=config,
-            block_len=block_len,
+            page_len=page_len,
             lora_rank=lora_rank,
             seqlens=[seqlen] * batch_size,
             dtype=dtype,
@@ -124,7 +123,7 @@ def bench_model_lora_decode(f):
             head_dim=head_dim,
             num_layers=num_layers,
             intermediate_size=intermediate_size,
-            block_len=block_len,
+            page_len=page_len,
             lora_rank=lora_rank,
             num_lora_models=res.num_lora_models,
             seqlen=seqlen,
11 changes: 5 additions & 6 deletions benchmarks/bench_model_prefill_decode.py
@@ -19,7 +19,7 @@ class model_Resources:
     def __init__(
         self,
         config: LlamaConfig,
-        block_len: int,
+        page_len: int,
         seqlen: int,
         prefills: int,
         decodes: int,
@@ -32,8 +32,7 @@ def __init__(
             num_layers=config.num_hidden_layers,
             num_heads=num_heads,
             head_dim=head_dim,
-            capacity=(seqlen + block_len - 1) // block_len * (prefills + decodes),
-            block_len=block_len,
+            page_len=page_len,
             dtype=dtype,
             device=device,
         )
@@ -63,7 +62,7 @@ def bench_model_prefill_decode(f):
     seqlen_ = [128, 512, 1024, 1536, 2048]
     dtype = torch.float16
     device = torch.device("cuda:0")
-    block_len = 16
+    page_len = 16
     head_dim = 128

     all_ = list(
@@ -99,7 +98,7 @@ def bench_model_prefill_decode(f):
         gc_torch()
         res = model_Resources(
             config=config,
-            block_len=block_len,
+            page_len=page_len,
             seqlen=seqlen,
             prefills=batch_size * prefill,
             decodes=batch_size * decode,
@@ -111,7 +110,7 @@ def bench_model_prefill_decode(f):
             head_dim=head_dim,
             num_layers=num_layers,
             intermediate_size=intermediate_size,
-            block_len=block_len,
+            page_len=page_len,
             seqlen=seqlen,
             prefills=batch_size * prefill,
             decodes=batch_size * decode,
3 changes: 1 addition & 2 deletions benchmarks/bench_textgen.py
@@ -95,8 +95,7 @@ def textgen_punica(
         num_layers=model_cfg.num_layers,
         num_heads=model_cfg.num_kv_heads,
         head_dim=model_cfg.hidden_size // model_cfg.num_qo_heads,
-        capacity=textgen_cfg.batch_size * 2048 // 16,
-        block_len=16,
+        page_len=16,
         dtype=dtype,
         device=device,
     )
3 changes: 1 addition & 2 deletions benchmarks/bench_textgen_lora.py
@@ -101,8 +101,7 @@ def lora_punica(
         num_layers=model_cfg.num_layers,
         num_heads=model_cfg.num_kv_heads,
         head_dim=model_cfg.hidden_size // model_cfg.num_qo_heads,
-        capacity=textgen_cfg.batch_size * 2048 // 16,
-        block_len=16,
+        page_len=16,
         dtype=dtype,
         device=device,
     )
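
Both textgen benchmarks previously reserved the absolute worst case: `batch_size * 2048 // 16` pages, i.e. every request padded to the 2048-token maximum at 16 tokens per page. A small arithmetic sketch of how much that overshoots typical usage (the example sequence lengths are assumed; the 2048 and 16 constants come from the deleted line):

#include <cstdio>

int main() {
  const int page_len = 16, max_seqlen = 2048, batch_size = 4;
  // Old scheme: reserve the worst case for every request up front.
  const int preallocated = batch_size * max_seqlen / page_len;  // 512 pages
  // New scheme: only the pages the current lengths actually touch.
  const int seqlens[] = {100, 900, 30, 2048};
  int used = 0;
  for (int l : seqlens) used += (l + page_len - 1) / page_len;  // ceil-div
  std::printf("preallocated=%d used=%d\n", preallocated, used);  // 512 vs 194
  return 0;
}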
77 changes: 38 additions & 39 deletions csrc/flashinfer_adapter/flashinfer_all.cu
@@ -11,83 +11,82 @@ using flashinfer::PageStorage;
 using flashinfer::RotaryMode;

 template <typename T>
-void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr,
-                                 int32_t* kv_indicies,
-                                 int32_t* last_page_offset, int head_dim,
-                                 int num_layers, int layer_idx,
+void FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr,
+                                 int32_t* last_page_offset, void* kv_aux,
+                                 int head_dim, int num_layers, int layer_idx,
                                  int num_qo_heads, int num_kv_heads,
                                  int page_size, int batch_size) {
-  paged_kv_t<PageStorage::kIndices, T, int32_t> paged_kv(
+  paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
       num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
-      kv_data, kv_indicies, kv_indptr, last_page_offset);
+      kv_ptrs, kv_indptr, last_page_offset, (int32_t*)kv_aux);
   flashinfer::BatchDecodeWithPagedKVCache(q, paged_kv, o, nullptr, num_qo_heads,
                                           RotaryMode::kLlama);
 }

 template <int head_dim, typename T>
-void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr,
-                            int32_t* kv_indicies, int32_t* last_page_offset,
-                            T* key, T* value, int32_t* seqlen_indptr,
-                            int num_layers, int layer_idx, int num_kv_heads,
-                            int page_size, int batch_size) {
-  paged_kv_t<PageStorage::kIndices, T, int32_t> paged_kv(
+void FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr,
+                            int32_t* last_page_offset, T* key, T* value,
+                            int32_t* seqlen_indptr, int num_layers,
+                            int layer_idx, int num_kv_heads, int page_size,
+                            int batch_size) {
+  paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
       num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
-      kv_data, kv_indicies, kv_indptr, last_page_offset);
+      kv_ptrs, kv_indptr, last_page_offset);

   constexpr size_t vec_size =
       std::max(16 / sizeof(T), static_cast<size_t>(head_dim / 32));
   constexpr size_t bdx = head_dim / vec_size;
-  constexpr size_t bdy = 128 / bdx;
+  constexpr size_t bdy = 1;
   dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
   dim3 nthrs(bdx, bdy);
   flashinfer::AppendPagedKVCachePrefillKernel<head_dim, vec_size, bdx, bdy,
-                                              PageStorage::kIndices, T, int32_t>
+                                              PageStorage::kPointer, T, int32_t>
       <<<nblks, nthrs>>>(paged_kv, key, value, seqlen_indptr);
 }

 template <int head_dim, typename T>
-void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr,
-                              int32_t* kv_indicies, int32_t* last_page_offset,
-                              T* key, T* value, int num_layers, int layer_idx,
-                              int num_kv_heads, int page_size, int batch_size) {
-  paged_kv_t<PageStorage::kIndices, T, int32_t> paged_kv(
+void FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
+                              int32_t* last_page_offset, T* key, T* value,
+                              int num_layers, int layer_idx, int num_kv_heads,
+                              int page_size, int batch_size) {
+  paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
       num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
-      kv_data, kv_indicies, kv_indptr, last_page_offset);
+      kv_ptrs, kv_indptr, last_page_offset);

   constexpr size_t vec_size =
       std::max(16 / sizeof(T), static_cast<size_t>(head_dim / 32));
   constexpr size_t bdx = head_dim / vec_size;
-  constexpr size_t bdy = 128 / bdx;
+  constexpr size_t bdy = 1;
   dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
   dim3 nthrs(bdx, bdy);
   flashinfer::AppendPagedKVCacheDecodeKernel<head_dim, vec_size, bdx, bdy,
-                                             PageStorage::kIndices, T, int32_t>
+                                             PageStorage::kPointer, T, int32_t>
       <<<nblks, nthrs>>>(paged_kv, key, value);
 }

-#define INST_FlashInferBatchDecodeKernel(T) \
-  template void FlashInferBatchDecodeKernel<T>( \
-      T * o, T * q, T * kv_data, int32_t * kv_indptr, int32_t * kv_indicies, \
-      int32_t * last_page_offset, int head_dim, int num_layers, int layer_idx, \
-      int num_qo_heads, int num_kv_heads, int page_size, int batch_size);
+#define INST_FlashInferBatchDecodeKernel(T) \
+  template void FlashInferBatchDecodeKernel<T>( \
+      T * o, T * q, T * *kv_ptrs, int32_t * kv_indptr, \
+      int32_t * last_page_offset, void* kv_aux, int head_dim, int num_layers, \
+      int layer_idx, int num_qo_heads, int num_kv_heads, int page_size, \
+      int batch_size);

 INST_FlashInferBatchDecodeKernel(nv_half);
 INST_FlashInferBatchDecodeKernel(nv_bfloat16);

-#define INST_FlashInferInitKvKernel(head_dim, T) \
-  template void FlashInferInitKvKernel<head_dim, T>( \
-      T * kv_data, int32_t * kv_indptr, int32_t * kv_indicies, \
-      int32_t * last_page_offset, T * key, T * value, int32_t * seqlen_indptr, \
-      int num_layers, int layer_idx, int num_kv_heads, int page_size, \
-      int batch_size);
+#define INST_FlashInferInitKvKernel(head_dim, T) \
+  template void FlashInferInitKvKernel<head_dim, T>( \
+      T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \
+      T * value, int32_t * seqlen_indptr, int num_layers, int layer_idx, \
+      int num_kv_heads, int page_size, int batch_size);

 FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_half);
 FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_bfloat16);

-#define INST_FlashInferAppendKvKernel(head_dim, T) \
-  template void FlashInferAppendKvKernel<head_dim, T>( \
-      T * kv_data, int32_t * kv_indptr, int32_t * kv_indicies, \
-      int32_t * last_page_offset, T * key, T * value, int num_layers, \
-      int layer_idx, int num_kv_heads, int page_size, int batch_size);
+#define INST_FlashInferAppendKvKernel(head_dim, T) \
+  template void FlashInferAppendKvKernel<head_dim, T>( \
+      T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \
+      T * value, int num_layers, int layer_idx, int num_kv_heads, \
+      int page_size, int batch_size);
 FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_half);
 FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_bfloat16);
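
All three kernels now take `T** kv_ptrs`, a device-resident array of per-page base pointers, in place of `kv_data` plus `kv_indicies`; the decode entry point additionally threads an opaque `kv_aux` buffer through to `paged_kv_t`, and the append/init launch configs drop from `bdy = 128 / bdx` to `bdy = 1`, i.e. one (batch, head) pair per thread block. A sketch of how such a pointer table could be assembled on the host (page ownership and ordering here are assumptions, not this repo's code):

#include <cuda_runtime.h>
#include <vector>

// Upload a host-side list of per-page device pointers so the kernels can
// dereference kv_ptrs[page_id] directly. Error handling elided.
template <typename T>
T** UploadKvPtrs(const std::vector<T*>& pages) {
  T** kv_ptrs = nullptr;
  cudaMalloc(&kv_ptrs, pages.size() * sizeof(T*));
  cudaMemcpy(kv_ptrs, pages.data(), pages.size() * sizeof(T*),
             cudaMemcpyHostToDevice);
  return kv_ptrs;
}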
25 changes: 12 additions & 13 deletions csrc/flashinfer_adapter/flashinfer_config.h
@@ -1,25 +1,24 @@
 #pragma once

 template <typename T>
-void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr,
-                                 int32_t* kv_indicies,
-                                 int32_t* last_page_offset, int head_dim,
-                                 int num_layers, int layer_idx,
+void FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr,
+                                 int32_t* last_page_offset, void* kv_aux,
+                                 int head_dim, int num_layers, int layer_idx,
                                  int num_qo_heads, int num_kv_heads,
                                  int page_size, int batch_size);

 template <int head_dim, typename T>
-void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr,
-                            int32_t* kv_indicies, int32_t* last_page_offset,
-                            T* key, T* value, int32_t* seqlen_indptr,
-                            int num_layers, int layer_idx, int num_kv_heads,
-                            int page_size, int batch_size);
+void FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr,
+                            int32_t* last_page_offset, T* key, T* value,
+                            int32_t* seqlen_indptr, int num_layers,
+                            int layer_idx, int num_kv_heads, int page_size,
+                            int batch_size);

 template <int head_dim, typename T>
-void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr,
-                              int32_t* kv_indicies, int32_t* last_page_offset,
-                              T* key, T* value, int num_layers, int layer_idx,
-                              int num_kv_heads, int page_size, int batch_size);
+void FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
+                              int32_t* last_page_offset, T* key, T* value,
+                              int num_layers, int layer_idx, int num_kv_heads,
+                              int page_size, int batch_size);

 // clang-format off
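
`FOR_FlashInferBatchDecode_D` (applied to the `INST_...` macros in flashinfer_all.cu above) appears to be an X-macro: it expands a caller-supplied macro once per supported head_dim to stamp out explicit template instantiations. A self-contained sketch of the idiom, with an assumed head_dim list:

#include <cstdio>

// The template being instantiated (body shown so the instantiations link).
template <int head_dim, typename T>
void Kernel(T* x) {
  (void)x;
  std::printf("head_dim=%d\n", head_dim);
}

// X-macro: apply `f` once per supported head_dim (list assumed here).
#define FOR_EACH_HEAD_DIM(f, T) f(64, T) f(128, T)

// Explicit instantiation, one per (head_dim, T) pair.
#define INST_Kernel(head_dim, T) template void Kernel<head_dim, T>(T*);

FOR_EACH_HEAD_DIM(INST_Kernel, float)  // emits Kernel<64, float> and Kernel<128, float>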