[ROCm] Add attention kv cache for decoding (#16076)
cloudhan authored Jun 16, 2023
1 parent 9647149 commit 9110e5b
Showing 9 changed files with 435 additions and 88 deletions.
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -180,6 +180,12 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
const half* past,
const half* k_v,
half* present);

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                         T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                         int max_threads_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
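
For orientation: the int4 shape and longlong4 strides taken by the new LaunchStridedCopy are read as (b, n, s, h), i.e. batch, head, sequence position, element within a head. The following host-side sketch is not part of the commit; the helper names are made up, and it only spells out what the strides of an ordinary contiguous BNSH buffer look like and how a coordinate maps to a flat element offset.

#include <cuda_runtime.h>  // longlong4, make_longlong4

// Element strides of a buffer stored contiguously as [batch, num_heads, seq_len, head_size].
inline longlong4 ContiguousBNSHStrides(int num_heads, int seq_len, int head_size) {
  return make_longlong4(static_cast<long long>(num_heads) * seq_len * head_size,  // b: next batch
                        static_cast<long long>(seq_len) * head_size,              // n: next head
                        static_cast<long long>(head_size),                        // s: next position
                        1LL);                                                     // h: next element
}

// Flat element offset of coordinate (b, n, s, h) under the strides above.
inline long long OffsetOf(longlong4 strides, int b, int n, int s, int h) {
  return b * strides.x + n * strides.y + s * strides.z + h * strides.w;
}
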
158 changes: 158 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
@@ -0,0 +1,158 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T>
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                            T* out, longlong4 out_strides                    // coord (b,n,s,h)
) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  if (h < H) {
    const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
    const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
  }
}

template <typename T>
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                                 T* out, longlong4 out_strides                    // coord (b,n,s,h)
) {
  // Used when H * num_heads exceeds max_threads_per_block (e.g. > 1024)
  int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;

  const int h_step = blockDim.x;

  while (h < H) {
    const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
    const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
    h += h_step;
  }
}

template <int NumBytes>
struct ToByteType;

template <>
struct ToByteType<2> {
using T = int16_t;
};

template <>
struct ToByteType<4> {
using T = int32_t;
};

template <>
struct ToByteType<8> {
using T = int64_t;
};

template <>
struct ToByteType<16> {
using T = uint4;
};

template <>
struct ToByteType<32> {
using T = ulonglong4;
};

template <int NumBytes>
using ToBytes = typename ToByteType<NumBytes>::T;
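
The ToByteType/ToBytes mapping above lets the launcher below move 2 or 4 packed elements per thread as a single wider scalar; the stride divisions in LaunchStridedCopy keep the offsets consistent after that reinterpretation. As a quick sanity check (illustrative only, not part of the file), four halfs travel as one int64_t:

#include <cuda_fp16.h>   // half
#include <type_traits>

static_assert(std::is_same<ToBytes<sizeof(half) * 4>, int64_t>::value,
              "four 2-byte halfs are copied as a single 8-byte load/store");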

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                         T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                         int max_threads_per_block) {
  int batch_size = in_shape.x;
  int num_heads = in_shape.y;
  int sequence_length = in_shape.z;
  int head_size = in_shape.w;
  if (sequence_length == 0) {
    return Status::OK();
  }

  const dim3 grid(sequence_length, batch_size);
  if (0 == (head_size % 4)) {  // pack 4 elements together
    using Bytes = ToBytes<sizeof(T) * 4>;
    const int H = head_size / 4;
    in_strides.x /= 4;
    in_strides.y /= 4;
    in_strides.z /= 4;
    out_strides.x /= 4;
    out_strides.y /= 4;
    out_strides.z /= 4;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
                                                     reinterpret_cast<Bytes*>(out), out_strides);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
                                                          reinterpret_cast<Bytes*>(out), out_strides);
    }
  } else if (0 == (head_size % 2)) {  // pack 2 elements together
    using Bytes = ToBytes<sizeof(T) * 2>;
    const int H = head_size / 2;
    in_strides.x /= 2;
    in_strides.y /= 2;
    in_strides.z /= 2;
    out_strides.x /= 2;
    out_strides.y /= 2;
    out_strides.z /= 2;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
                                                     reinterpret_cast<Bytes*>(out), out_strides);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
                                                          reinterpret_cast<Bytes*>(out), out_strides);
    }
  } else {
    using Bytes = ToBytes<sizeof(T)>;
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), head_size, in_strides,
                                                     reinterpret_cast<Bytes*>(out), out_strides);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), head_size, in_strides,
                                                          reinterpret_cast<Bytes*>(out), out_strides);
    }
  }
  return CUDA_CALL(cudaGetLastError());
}

template Status LaunchStridedCopy<float>(
cudaStream_t stream,
const float* in, int4 in_shape, longlong4 in_strides,
float* out, longlong4 out_strides,
int max_threads_per_block);

template Status LaunchStridedCopy<half>(
cudaStream_t stream,
const half* in, int4 in_shape, longlong4 in_strides,
half* out, longlong4 out_strides,
int max_threads_per_block);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
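
A usage sketch of the launcher defined above, not part of the commit: it appends a freshly computed K of shape [B, N, S_new, H] into a preallocated present-K buffer of shape [B, N, S_max, H] at sequence offset past_len, which mirrors how the ROCm Attention kernel below drives its HIP counterpart. All names are illustrative, and the snippet assumes the LaunchStridedCopy declaration from attention_impl.h is in scope.

#include <cuda_fp16.h>
#include <cuda_runtime.h>

void AppendKToPresent(cudaStream_t stream, const half* new_k, half* present_k,
                      int B, int N, int S_new, int S_max, int H,
                      int past_len, int max_threads_per_block) {
  const int4 shape = make_int4(B, N, S_new, H);
  const longlong4 src_strides = make_longlong4(static_cast<long long>(N) * S_new * H,
                                               static_cast<long long>(S_new) * H, H, 1);
  const longlong4 dst_strides = make_longlong4(static_cast<long long>(N) * S_max * H,
                                               static_cast<long long>(S_max) * H, H, 1);
  // Advance the destination past the cached prefix; dst_strides then walk the larger buffer,
  // so element (b, n, s, h) of new_k lands at sequence position past_len + s of present_k.
  half* dst = present_k + static_cast<long long>(past_len) * dst_strides.z;
  // Return value ignored for brevity; real call sites wrap this in ORT_RETURN_IF_ERROR.
  (void)LaunchStridedCopy<half>(stream, new_k, shape, src_strides, dst, dst_strides,
                                max_threads_per_block);
}
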
61 changes: 41 additions & 20 deletions onnxruntime/contrib_ops/rocm/bert/attention.cu
@@ -88,10 +88,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present}));
- // TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE
  ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE ||
              attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE ||
-             attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE);
+             attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE ||
+             attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE ||
+             attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE);

size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn),
@@ -123,27 +124,46 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params));
auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params);

+ // NOTE: GemmPermute always outputs 3BNSH; k_buffer and v_buffer can be treated as 2BNSH
  if (nullptr != present) {
-   // Concat past (2xBxNxS'xH) to present (2xBxNxTxH):
-   // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxTxH)
-   // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxTxH)
-   const int batches = attn.batch_size * attn.num_heads;
-   const int present_size_per_batch = attn.total_sequence_length * attn.head_size;
-   ORT_RETURN_IF_ERROR(
-       LaunchConcatPastToPresent(Stream(context),
-                                 attn.total_sequence_length,
-                                 attn.sequence_length,
-                                 attn.batch_size,
-                                 attn.head_size,
-                                 attn.num_heads,
-                                 device_prop.maxThreadsPerBlock,
-                                 nullptr == past ? nullptr : reinterpret_cast<const HipT*>(past->DataRaw()),
-                                 k_buffer,
-                                 reinterpret_cast<HipT*>(present->MutableDataRaw())));
+   Strides dst_strides;  // the output buffer is the present Tensor; the buffer is the same for all branches below

+   int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size};
+   HipT* add_dest = nullptr;              // destination of concatenated data to present
+   const HipT* const add_src = k_buffer;  // source of concatenated data to present
+   const auto add_src_strides = Strides::BNSHMemory(
+       2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size);

+   if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) {
+     dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
+     add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
+   } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) {
+     dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
+     add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);

+     // We only need to copy past to present in this case; all other cases build the present incrementally.
+     const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size};
+     HipT* const past_dest = reinterpret_cast<HipT*>(present->MutableDataRaw());
+     const HipT* const past_src = reinterpret_cast<const HipT*>(past->DataRaw());
+     const Strides past_src_strides = Strides::BNSHMemory(
+         2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size);

+     ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(),
+                                           past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
+   } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) {
+     dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
+     add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
+   } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) {
+     dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
+     add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
+   }

+   ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(),
+                                         add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));

-   // update pointers to present_k and present_v.
+   // update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews
    k_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw());
-   v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + batches * present_size_per_batch;
+   v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0);
  }

// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
@@ -160,6 +180,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
params.device_prop = &device_prop;
// FIXME: the params.scale seems to be different from AttentionParameters::scale;
params.scale = 1.0f / sqrt(static_cast<float>(attn.head_size));
+ // TODO: switch to ConvertToOffsetedBufferViews
params.q_buffer = q_buffer;
params.k_buffer = k_buffer;
params.v_buffer = v_buffer;
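
To make the destination-offset arithmetic in the KV-cache block above concrete, here is a small worked example with illustrative numbers (not taken from the diff), assuming Strides::BNSHMemory produces ordinary row-major BNSH element strides, which is consistent with how it is used here:

// batch_size = 1, num_heads = 8, head_size = 64,
// past_sequence_length = 15, max_sequence_length = 256, shared KV buffer (2BNMH present)
//
// dst_strides = Strides::BNSHMemory(2 * 1, 8, 256, 64)
//             -> per-coordinate steps (8 * 256 * 64, 256 * 64, 64, 1) = (131072, 16384, 64, 1)
// add_dest    = present + dst_strides.OffsetAt(0, 0, 15, 0)
//             -> present + 15 * 64 = present + 960, the slot right after the 15 cached positions
// v_buffer    = present + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0)
//             -> present + 1 * 131072, the start of the V half of the 2 x B x N x M x H buffer
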
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -122,6 +122,24 @@ Status ClassifyAttentionMode(
      attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE;
      return Status::OK();
    }
  } else if (num_qkv == 3 && num_past == 0 && num_present == 2) {
    if (attn->past_present_share_buffer == false) {
      if (attn->qkv_format == Q_K_V_BSNH) {
        attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH;
        return Status::OK();
      } else if (attn->pass_past_in_kv) {
        attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH;
        return Status::OK();
      }
    } else {
      if (attn->qkv_format == Q_K_V_BSNH) {
        attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH;
        return Status::OK();
      } else if (attn->pass_past_in_kv) {
        attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH;
        return Status::OK();
      }
    }
  } else if (num_qkv == 3 && num_past == 2 && num_present == 2) {
    if (attn->past_present_share_buffer == false) {
      if (attn->qkv_format == Q_K_V_BSNH) {
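
A note on the mode identifiers used above and declared in the header below (an inference from the names and surrounding code, not stated in the commit): each underscore-separated field appears to describe, in order, the layout of Q, K, V, past_k, past_v, present_k and present_v, with B = batch, S = query length, L = key/value length, P = past length, T = total length, M = max (shared-buffer) length, N = heads and H = head size. For example, BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH would mean separate Q/K/V inputs (Q in BSNH, K/V in BNLH), no past tensors, and present K/V written into max-length shared buffers.
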
31 changes: 9 additions & 22 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -98,28 +98,6 @@ Status LaunchConcatTensorToTensor(hipStream_t stream,
const half* tensor_add,
half* tensor_out);

- Status LaunchConcatPastToPresent(hipStream_t stream,
-                                  const int all_sequence_length,
-                                  const int sequence_length,
-                                  const int batch_size,
-                                  const int head_size,
-                                  const int num_heads,
-                                  const int max_threads_per_block,
-                                  const float* past,
-                                  const float* k_v,
-                                  float* present);
-
- Status LaunchConcatPastToPresent(hipStream_t stream,
-                                  const int all_sequence_length,
-                                  const int sequence_length,
-                                  const int batch_size,
-                                  const int head_size,
-                                  const int num_heads,
-                                  const int max_threads_per_block,
-                                  const half* past,
-                                  const half* k_v,
-                                  half* present);

inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
@@ -191,6 +169,10 @@ enum AttentionMode {
QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE,
BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE,
BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE,
+ BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH,
+ BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH,
+ BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH,
+ BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH,
BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH,
BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH,
BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH,
@@ -209,6 +191,11 @@ Status ClassifyAttentionMode(const std::string& op,
const std::vector<const Tensor*>& past,
const std::vector<Tensor*>& present);

+ template <typename T>
+ Status LaunchStridedCopy(hipStream_t stream,
+                          const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
+                          T* out, longlong4 out_strides,                     // coord (b,n,s,h)
+                          int max_threads_per_block);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime