From e915517e338c51fbeff574eac2451abbf071c9ba Mon Sep 17 00:00:00 2001 From: zhaochaoxing <109726331+zhaochaoxing@users.noreply.github.com> Date: Wed, 12 Jun 2024 11:24:05 +0800 Subject: [PATCH] [ascend]Zcx/llama2 infer 910b (#1254) * optimize lightllm * fix promptFlashAttention on a+x * add check for incre flash attention * add description of the added funtion --- .github/workflows/_runs-on-nv-step1.yml | 4 +- diopi_test/python/configs/diopi_configs.py | 158 +++++++++ .../python/conformance/customized_test.py | 112 ++++++- .../python/conformance/diopi_functions.py | 306 ++++++++++++++++-- .../python/conformance/global_op_list.py | 135 ++++---- impl/ascend_npu/ascend_config.yaml | 5 + .../functions_ext/apply_penalty.cpp | 22 ++ .../context_attention_inference.cpp | 33 ++ .../functions_ext/destindex_copy_kv.cpp | 11 +- .../functions_ext/matmul_all_reduce.cpp | 19 ++ .../functions_ext/rotary_embedding.cpp | 17 + .../token_attention_inference.cpp | 48 +++ proto/include/diopi/functions_ext.h | 86 +++++ 13 files changed, 861 insertions(+), 95 deletions(-) create mode 100644 impl/ascend_npu/diopi_impl/functions_ext/matmul_all_reduce.cpp diff --git a/.github/workflows/_runs-on-nv-step1.yml b/.github/workflows/_runs-on-nv-step1.yml index 49e8718dd6..845288aa4b 100644 --- a/.github/workflows/_runs-on-nv-step1.yml +++ b/.github/workflows/_runs-on-nv-step1.yml @@ -78,7 +78,7 @@ jobs: && python main.py --mode gen_data" \ || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 ) source ~/Aoss_env.sh - ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ + ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ elif [[ "${GETRUNNER}" == *diopi* ]];then ssh SH1424 """ set -e @@ -87,7 +87,7 @@ jobs: srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_V100} --time=20 --gres=gpu:1 bash -c 'python main.py --mode gen_data' \ || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 ) source ~/Aoss_env.sh - ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ + ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ """ else ln -s ${GEN_DATA_PATH}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 4d9eb69e94..fcaae81a21 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -8305,6 +8305,36 @@ ), ), + 'rotary_emb_v2': dict( + name=['rotary_emb_v2'], + interface=['CustomizedTest'], + dtype=[np.float32, np.float16], + para=dict( + dim=[128,] + ), + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ 
+ { + "ins": ['query'], + "shape": ((8, 4096),), + }, + { + "ins": ['key'], + "shape": ((8, 4096),), + }, + { + "ins": ['cos'], + "shape": ((8, 1, 128),), + }, + { + "ins": ['sin'], + "shape": ((8, 1, 128),), + }, + ], + ), + ), + 'rms_norm_default': dict( name=['rms_norm'], atol=1e-4, @@ -8551,6 +8581,134 @@ ), ), + 'prompt_flash_attention': dict( + name=['prompt_flash_attention'], + interface=['CustomizedTest'], + atol=1e-2, + rtol=1e-2, + para=dict( + maxInputLen=[2,], + actualSeqLengths=[[2,2],], + numHeads=[32,], + numKeyValueHeads=[32,], + dim=[128,], + ), + tensor_para=dict( + args=[ + { + "ins": ["query"], + "shape": ((4, 4096),), + "dtype": [np.float16,], + }, + { + "ins": ["key"], + "shape": ((4, 4096),), + "dtype": [np.float16,], + }, + { + "ins": ["value"], + "shape": ((4, 4096),), + "dtype": [np.float16,], + }, + { + "ins": ["attenMask"], + "value": ([[[False, True], + [False, False]], + [[False, True], + [False, False]]],), + "dtype": [np.bool_,], + "gen_policy": "gen_tensor_by_value" + }, + ] + ), + ), + + 'paged_attention': dict( + name=['paged_attention'], + interface=['CustomizedTest'], + atol=1e-2, + rtol=1e-2, + para=dict( + actualSeqLengths=[[150,],], + numHeads=[32,], + numKeyValueHeads=[32,], + dim=[128,], + blockSize=[128,], + ), + tensor_para=dict( + args=[ + { + "ins": ["query"], + "shape": ((1, 4096),), + "dtype": [np.float16,], + }, + { + "ins": ["key"], + "shape": ((1026, 4096),), + "dtype": [np.float16,], + }, + { + "ins": ["value"], + "shape": ((1026, 4096),), + "dtype": [np.float16,], + }, + { + "ins": ["blockTable"], + "value": ([[0, 1],],), + "dtype": [np.int32,], + "gen_policy": "gen_tensor_by_value" + }, + ] + ), + ), + + 'apply_penalty_v2': dict( + name=['apply_penalty_v2'], + interface=['CustomizedTest'], + tensor_para=dict( + args=[ + { + "ins": ['logits'], + "value": ([[0.1, 0.5, 0.4, 0.3, 0.5], + [0.2, 0.4, 0.0, 0.0, 0.0], + [0.3, 0.4, 0.5, 0.3, 0.0]],), + "dtype": [np.float16, np.float32], + "gen_policy": "gen_tensor_by_value" + }, + { + "ins": ["presence_penalty"], + "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],), + "dtype": [np.float16, np.float32], + "gen_policy": "gen_tensor_by_value" + }, + { + "ins": ["frequency_penalty"], + "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],), + "dtype": [np.float16, np.float32], + "gen_policy": "gen_tensor_by_value" + }, + { + "ins": ["repetition_penalty"], + "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],), + "dtype": [np.float16, np.float32], + "gen_policy": "gen_tensor_by_value" + }, + { + "ins": ["p_token_ids"], + "value": ([0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11],), + "dtype": [np.int32, np.int32], + "gen_policy": "gen_tensor_by_value" + }, + { + "ins": ["p_token_counts"], + "value": ([3, 3, 2, 2, 1, 3, 3, 3, 3, 2, 2],), + "dtype": [np.int32, np.int32], + "gen_policy": "gen_tensor_by_value" + }, + ] + ) + ), + 'token_attention': dict( name=['token_attention'], interface=['CustomizedTest'], diff --git a/diopi_test/python/conformance/customized_test.py b/diopi_test/python/conformance/customized_test.py index a51342a2c2..02913403d5 100644 --- a/diopi_test/python/conformance/customized_test.py +++ b/diopi_test/python/conformance/customized_test.py @@ -626,6 +626,108 @@ def context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len): ) return out + def prompt_flash_attention( + query, + key, + value, + attenMask, + actualSeqLengths, + maxInputLen, + numHeads, + numKeyValueHeads, + dim, + ): + bs = len(actualSeqLengths) + xq = query.view(bs, 
maxInputLen, numHeads, dim).cuda() + keys = key.view(bs, maxInputLen, numKeyValueHeads, dim).cuda() + values = value.view(bs, maxInputLen, numKeyValueHeads, dim).cuda() + mask = ( + torch.tril(torch.ones(maxInputLen, maxInputLen), diagonal=0) + .unsqueeze(0) + .unsqueeze(0) + .cuda() + ) + mask = mask.masked_fill(mask == 0.0, -100000000.0) + mask = mask.repeat(bs, numHeads, 1, 1) + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) + out = torch.matmul(scores, values).transpose(1, 2).contiguous() + return out.reshape(bs * maxInputLen, numHeads * dim) + + def paged_attention( + query, + key, + value, + actualSeqLengths, + numHeads, + numKeyValueHeads, + dim, + blockTable, + blockSize, + ): + # q: BSH + b_loc = torch.arange(key.shape[0], dtype=torch.int32).reshape(1, -1).cuda() + batch = b_loc.shape[0] + xq = query.view(batch, 1, numHeads, dim).transpose(1, 2).cuda() + k = key.view(-1, numKeyValueHeads, dim).cuda() + v = value.view(-1, numKeyValueHeads, dim).cuda() + out = torch.empty([batch, numHeads, dim], device="cuda", dtype=query.dtype) + max_input_len = max(actualSeqLengths) + b_seq_len = torch.tensor(actualSeqLengths, dtype=torch.int32).cuda() + for i in range(batch): + k_loc = b_loc[i][ + max_input_len + - b_seq_len[i] + + torch.arange(0, b_seq_len[i], device="cuda", dtype=torch.int32) + ] + key = k[k_loc, :].view(1, b_seq_len[i], numHeads, dim).transpose(1, 2) + logics = ( + torch.matmul(xq[i, :], key.transpose(2, 3)) / math.sqrt(dim) + ).reshape(numHeads, b_seq_len[i]) + v_loc = b_loc[i][ + max_input_len + - b_seq_len[i] + + torch.arange(0, b_seq_len[i], device=logics.device, dtype=torch.int32) + ] + P = logics.softmax(-1).reshape(1, numHeads, 1, b_seq_len[i]) + V = v[v_loc, :].view(1, b_seq_len[i], numHeads, dim).transpose(1, 2) + out[i, :] = torch.matmul(P, V).view(numHeads, dim) + return out.view(-1, numHeads * dim) + + def apply_penalty_v2( + logits, + presence_penalty, + frequency_penalty, + repetition_penalty, + p_token_ids, + p_token_counts, + ): + batch = logits.shape[0] + logits = logits.view(-1) + cur_logits = logits.index_select(0, p_token_ids) + rep_logits = torch.where( + cur_logits > 0, + cur_logits / repetition_penalty, + cur_logits * repetition_penalty, + ) + rep_logits = rep_logits - p_token_counts * frequency_penalty - presence_penalty + logits[p_token_ids] = rep_logits + return logits.view(batch, -1) + + def rotary_emb_v2(query, key, cos, sin, dim): + query = query.view(query.shape[0], -1, dim) + key = key.view(key.shape[0], -1, dim) + q1, q2 = query.chunk(2, dim=-1) + query_rotate = torch.cat((-q2, q1), dim=-1) + query = query * cos + query_rotate * sin + k1, k2 = key.chunk(2, dim=-1) + key_rotate = torch.cat((-k2, k1), dim=-1) + key = key * cos + key_rotate * sin + return query.view(query.shape[0], -1), key.view(key.shape[0], -1) + def attention( query, key, @@ -742,5 +844,13 @@ def attention_varlen( start_idx = cu_seqlens_q[i] end_idx = cu_seqlens_q[i + 1] actual_seq_len = end_idx - start_idx - out[start_idx:end_idx, :, :] = out_paded[i, :actual_seq_len, :, :] # BSND->TND + out[start_idx:end_idx, :, :] = out_paded[ + i, :actual_seq_len, :, : + ] # BSND->TND + return out + + def nll_loss_v2(input, target, weight=None, ignore_index=-100, reduction="mean"): + out = torch.nn.functional.nll_loss( + input, target, weight, None, ignore_index, None, reduction + ) return out diff --git 
a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 5b617a3447..af450fdf1d 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -2004,6 +2004,40 @@ def nll_loss(input, target, weight=None, ignore_index=-100, reduction="mean"): return out +def nll_loss_v2(input, target, weight=None, ignore_index=-100, reduction="mean"): + assert reduction in [ + "mean", + "sum", + "none", + ], "reduction must be one of (mean, sum, none)" + + if weight is not None: + assert isinstance(weight, Tensor), "weigth must be a Tensor" + + if reduction == "none": + out = Tensor(target.size().data, input.get_dtype()) + else: + out = Tensor((), input.get_dtype()) + + totalWeight = Tensor((1,), input.get_dtype()) + + reduction_mode = convert_reduction(reduction) + func = check_function("diopiNLLLossV2") + ret = func( + input.context(), + out, + totalWeight, + input, + target, + weight, + reduction_mode, + ignore_index, + ) + check_returncode(ret) + GLOBAL_STATE["nll_loss_v2_totalWeight"] = totalWeight + return out + + def sigmoid_focal_loss( inputs, targets, alpha=0.25, gamma=2, reduction="none" ) -> Tensor: @@ -2799,6 +2833,40 @@ def nll_loss_backward( return {"input": grad_input} +def nll_loss_v2_backward( + input, + grad_outputs, + target, + weight=None, + ignore_index=-100, + reduction="mean", + **kwargs, +) -> Tensor: + assert len(grad_outputs) == 1, "only accept 1 gradient to do backward" + grad_input = raw_like(input) + + if weight is not None: + assert isinstance(weight, Tensor), "weigth must be a Tensor" + + reduction_mode = convert_reduction(reduction) + + totalWeight = GLOBAL_STATE.pop("nll_loss_v2_totalWeight") + func = check_function("diopiNLLLossV2Backward") + ret = func( + input.context(), + grad_input, + grad_outputs[0], + input, + target, + weight, + totalWeight, + reduction_mode, + ignore_index, + ) + check_returncode(ret) + return {"input": grad_input} + + def max_pool2d_backward( input, grad_outputs, @@ -5086,9 +5154,20 @@ def rms_norm_backward(grad_outputs, input, weight, bias, normalized_shape, eps): # If not specified, normalized_shape generally defaults to the size of the last dimension of the input tensor. 
normalized_shape = Sizes(normalized_shape) - inv_rms = GLOBAL_STATE.pop('rms_norm_inv_rms') - ret = func(input.context(), grad_input, grad_weight, grad_bias, grad_outputs[0], input, weight, bias, inv_rms, - normalized_shape, eps) + inv_rms = GLOBAL_STATE.pop("rms_norm_inv_rms") + ret = func( + input.context(), + grad_input, + grad_weight, + grad_bias, + grad_outputs[0], + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, + ) check_returncode(ret) if bias is None: return {"input": grad_input, "weight": grad_weight} @@ -5441,23 +5520,55 @@ def flash_attention_v3(q, k, v, p_dropout, softmax_scale, is_causal): GLOBAL_STATE["flash_attention_v3_generator"] = generator return out -def flash_attention_v3_backward(q, k, v, out, grad_outputs, p_dropout, softmax_scale, is_causal): + +def flash_attention_v3_backward( + q, k, v, out, grad_outputs, p_dropout, softmax_scale, is_causal +): call = "diopiFlashAttentionV3Backward" func = check_function(call) - assert p_dropout >=0 and p_dropout <=1, "The p_dropout value must be in range of [0, 1]" + assert ( + p_dropout >= 0 and p_dropout <= 1 + ), "The p_dropout value must be in range of [0, 1]" grad_q = raw_like(q) grad_k = raw_like(k) grad_v = raw_like(v) q_size = list(q.size().data) head_dim = q_size[-1] - softmax_lse = GLOBAL_STATE.pop('flash_attention_v3_softmax_lse') - generator = GLOBAL_STATE.pop('flash_attention_v3_generator') + softmax_lse = GLOBAL_STATE.pop("flash_attention_v3_softmax_lse") + generator = GLOBAL_STATE.pop("flash_attention_v3_generator") softmax_scale = 1.0 / math.sqrt(head_dim) if not softmax_scale else softmax_scale - ret = func(q.context(), grad_q, grad_k, grad_v, grad_outputs[0], generator, q, k, v, out, softmax_lse, p_dropout, softmax_scale, is_causal) + ret = func( + q.context(), + grad_q, + grad_k, + grad_v, + grad_outputs[0], + generator, + q, + k, + v, + out, + softmax_lse, + p_dropout, + softmax_scale, + is_causal, + ) check_returncode(ret) - return {'q': grad_q, 'k': grad_k, 'v': grad_v} + return {"q": grad_q, "k": grad_k, "v": grad_v} + -def flash_attention_varlen(q, k, v, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal): +def flash_attention_varlen( + q, + k, + v, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + p_dropout, + softmax_scale, + is_causal, +): call = "diopiFlashAttentionVarLen" func = check_function(call) q_size = list(q.size().data) @@ -5485,7 +5596,9 @@ def flash_attention_varlen(q, k, v, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, c softmax_max_ptr = TensorP(softmax_max) softmax_sum_ptr = TensorP(softmax_sum) softmax_out_ptr = TensorP(softmax_out) - softmax_scale = 1.0 / math.sqrt(q.shape().data[-1]) if not softmax_scale else softmax_scale + softmax_scale = ( + 1.0 / math.sqrt(q.shape().data[-1]) if not softmax_scale else softmax_scale + ) ret = func( q.context(), out, @@ -5514,10 +5627,26 @@ def flash_attention_varlen(q, k, v, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, c GLOBAL_STATE["flash_attention_varlen_softmax_out"] = softmax_out return out -def flash_attention_varlen_backward(q, k, v, out, grad_outputs, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal): + +def flash_attention_varlen_backward( + q, + k, + v, + out, + grad_outputs, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + p_dropout, + softmax_scale, + is_causal, +): call = "diopiFlashAttentionVarLenBackward" func = check_function(call) - assert p_dropout >=0 and p_dropout <=1, "The p_dropout value must 
be in range of [0, 1]" + assert ( + p_dropout >= 0 and p_dropout <= 1 + ), "The p_dropout value must be in range of [0, 1]" head_dim = q.shape().data[-1] softmax_scale = 1.0 / math.sqrt(head_dim) if not softmax_scale else softmax_scale cu_seqlens_q = Sizes(cu_seqlens_q[1:]) @@ -5555,7 +5684,17 @@ def flash_attention_varlen_backward(q, k, v, out, grad_outputs, max_seqlen_q, ma check_returncode(ret) return out -def attention(query, key, value, attn_mask=None, attn_bias=None, dropout_p=0.0, is_causal=False, scale=None): + +def attention( + query, + key, + value, + attn_mask=None, + attn_bias=None, + dropout_p=0.0, + is_causal=False, + scale=None, +): func = check_function("diopiAttention") attn_out = raw_like(query) max_tensor_num_for_backward = 16 @@ -5565,8 +5704,21 @@ def attention(query, key, value, attn_mask=None, attn_bias=None, dropout_p=0.0, generator = Generator(build_generator_state(query.context())) if scale is None: scale = 1.0 / math.sqrt(query.size().data[-1]) - ret = func(query.context(), attn_out, save_for_backward, get_capsule(byref(save_tensor_num)), - query, key, value, attn_mask, attn_bias, dropout_p, generator, scale, is_causal) + ret = func( + query.context(), + attn_out, + save_for_backward, + get_capsule(byref(save_tensor_num)), + query, + key, + value, + attn_mask, + attn_bias, + dropout_p, + generator, + scale, + is_causal, + ) check_returncode(ret) save_for_backward_tensor_list = [] for i in range(save_tensor_num.value): @@ -5579,15 +5731,7 @@ def attention(query, key, value, attn_mask=None, attn_bias=None, dropout_p=0.0, def attention_backward( - query, - key, - value, - out, - grad_outputs, - dropout_p, - scale, - is_causal, - attn_bias + query, key, value, out, grad_outputs, dropout_p, scale, is_causal, attn_bias ): call = "diopiAttentionBackward" func = check_function(call) @@ -5622,7 +5766,7 @@ def attention_backward( save_tensor_num, dropout_p, generator, - softmax_scale + softmax_scale, ) check_returncode(ret) return {"query": grad_q, "key": grad_k, "value": grad_v} @@ -5707,7 +5851,7 @@ def attention_varlen_backward( grad_v = raw_like(value) grad_attn_bias = None if attn_bias is not None: - #grad_attn_bias = raw_like(attn_bias) + # grad_attn_bias = raw_like(attn_bias) pass save_tensor_num = GLOBAL_STATE.pop("attention_save_tensor_num") save_for_backward_tensor_list = GLOBAL_STATE.pop("save_for_backward_tensor_list") @@ -5733,7 +5877,7 @@ def attention_varlen_backward( save_tensor_num, dropout_p, generator, - softmax_scale + softmax_scale, ) check_returncode(ret) return {"query": grad_q, "key": grad_k, "value": grad_v} @@ -5874,3 +6018,109 @@ def context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len): ret = func(q.context(), out, q, k, v, b_start_loc, b_seq_len, max_input_len) check_returncode(ret) return out + + +def prompt_flash_attention( + query, + key, + value, + attenMask, + actualSeqLengths, + maxInputLen, + numHeads, + numKeyValueHeads, + dim, +): + call = "diopiPromptFlashAttention" + func = check_function(call) + actualSeqLengths = Sizes(actualSeqLengths) + out = raw_like(query) + + ret = func( + query.context(), + out, + query, + key, + value, + attenMask, + actualSeqLengths, + maxInputLen, + numHeads, + numKeyValueHeads, + dim, + ) + check_returncode(ret) + return out + + +def paged_attention( + query, + key, + value, + actualSeqLengths, + numHeads, + numKeyValueHeads, + dim, + blockTable, + blockSize, +): + call = "diopiPagedAttention" + func = check_function(call) + actualSeqLengths = Sizes(actualSeqLengths) + out = 
raw_like(query) + + ret = func( + query.context(), + out, + query, + key, + value, + actualSeqLengths, + numHeads, + numKeyValueHeads, + dim, + blockTable, + blockSize, + ) + check_returncode(ret) + return out + + +def apply_penalty_v2( + logits, + presence_penalty, + frequency_penalty, + repetition_penalty, + p_token_ids, + p_token_counts, +): + call = "diopiApplyPenaltyV2" + func = check_function(call) + # some checks + p_token_ids_shape = list(p_token_ids.size().data) + p_token_counts_shape = list(p_token_counts.size().data) + + assert ( + p_token_ids_shape == p_token_counts_shape + ), "The shape of p_token_ids must be equal to the shape of p_token_counts." + + ret = func( + logits.context(), + logits, + presence_penalty, + frequency_penalty, + repetition_penalty, + p_token_ids, + p_token_counts, + ) + out = logits + check_returncode(ret) + return out + + +def rotary_emb_v2(query, key, cos, sin, dim): + call = "diopiRotaryEmbeddingV2" + func = check_function(call) + ret = func(query.context(), query, key, cos, sin, dim) + check_returncode(ret) + return query, key diff --git a/diopi_test/python/conformance/global_op_list.py b/diopi_test/python/conformance/global_op_list.py index 0ec189b6db..7ac5f61263 100644 --- a/diopi_test/python/conformance/global_op_list.py +++ b/diopi_test/python/conformance/global_op_list.py @@ -7,75 +7,84 @@ # 2. For camb test, adaptive_max_pool2d/max_pool2d need indices being int32 # Only conv2d, bn, adaptive_avg_pool2d, adaptive_max_pool2d can be tested, because # the rest have't been implemented. -nhwc_op = {'conv2d': ["2d", "input", 'weight'], - 'conv3d': ["3d", "input", 'weight'], - 'batch_norm': ['input'], - 'adaptive_avg_pool2d': ["2d", 'input'], - 'adaptive_max_pool2d': ["2d", 'input'], - 'adaptive_avg_pool3d': ["3d", 'input'], - 'adaptive_max_pool3d': ["3d", 'input'], - 'avg_pool2d': ["2d", 'input'], - 'max_pool2d': ["2d", 'input'], - # 'avg_pool3d': ["3d", 'input'], diopi doesn't hava avg_pool3d test - 'max_pool3d': ["3d", 'input'], - # both embedding - 'interpolate': ['input'], - 'pad': ['input'], - 'roi_align': ['input']} +nhwc_op = { + "conv2d": ["2d", "input", "weight"], + "conv3d": ["3d", "input", "weight"], + "batch_norm": ["input"], + "adaptive_avg_pool2d": ["2d", "input"], + "adaptive_max_pool2d": ["2d", "input"], + "adaptive_avg_pool3d": ["3d", "input"], + "adaptive_max_pool3d": ["3d", "input"], + "avg_pool2d": ["2d", "input"], + "max_pool2d": ["2d", "input"], + # 'avg_pool3d': ["3d", 'input'], diopi doesn't hava avg_pool3d test + "max_pool3d": ["3d", "input"], + # both embedding + "interpolate": ["input"], + "pad": ["input"], + "roi_align": ["input"], +} # Note : 1. camb test: all ops implemented is passed. # 2. nv test: most of ops is not implemented for 'Int'. # Tests of index_select, bce, embedding passed for 'Int'. 
-dtype_op = {'nll_loss': ['target'], # input using int32/float32 type - 'cross_entropy': ['target'], - 'index_select': ['index'], - 'index_put': ['indices1', 'indices2'], - 'binary_cross_entropy_with_logits': ['pos_weight'], - 'gather': ['index'], - 'scatter': ['index'], - 'embedding': ['input'], - 'index': ['idx1', 'idx2'], - 'ctc_loss': ['targets', 'input_lengths', 'target_lengths'], - 'index_fill': ['index'], - 'one_hot': ['input']} +dtype_op = { + "nll_loss": ["target"], # input using int32/float32 type + "cross_entropy": ["target"], + "index_select": ["index"], + "index_put": ["indices1", "indices2"], + "binary_cross_entropy_with_logits": ["pos_weight"], + "gather": ["index"], + "scatter": ["index"], + "embedding": ["input"], + "index": ["idx1", "idx2"], + "ctc_loss": ["targets", "input_lengths", "target_lengths"], + "index_fill": ["index"], + "one_hot": ["input"], +} # Note : 1. camb test: all ops implemented is passed. # 2. nv test: most of ops is not implemented for 'Int'. # Tests of unique, arange, randperm, argmax passed for 'Int'. -dtype_out_op = {'max_pool2d': ['indices'], # out using int32/float32 type - 'max_pool3d': ['indices'], - 'adaptive_max_pool2d': ['indices'], - 'adaptive_max_pool3d': ['indices'], - 'max': ['indices'], - 'min': ['indices'], - 'sort': ['indices'], - 'topk': ['indices'], - 'unique': ['indices'], - 'one_hot': ['out'], - 'arange': ['out'], - 'randperm': ['out'], - 'argmax': ['out']} +dtype_out_op = { + "max_pool2d": ["indices"], # out using int32/float32 type + "max_pool3d": ["indices"], + "adaptive_max_pool2d": ["indices"], + "adaptive_max_pool3d": ["indices"], + "max": ["indices"], + "min": ["indices"], + "sort": ["indices"], + "topk": ["indices"], + "unique": ["indices"], + "one_hot": ["out"], + "arange": ["out"], + "randperm": ["out"], + "argmax": ["out"], +} -ops_with_states = {"batch_norm": {"running_mean", "running_var"}, - "sgd": {"buf", "param"}, - "fill_": {"input"}, - "zero_": {"input"}, - "embedding": {"weight"}, - "adam": {"param", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"}, - "adamw": {"param", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"}, - "adadelta": {"param", "square_avg", "acc_delta"}, - "rmsprop": {"param", "square_avg", "grad_avg", "momentum_buffer"}, - "copy_": {"input"}, - "cast_dtype": {"out"}, - "batch_norm_gather_stats_with_counts": {"running_mean", "running_var"}, - "clip_grad_norm_": {"tensors"}, - "apply_penalty": {"logits"}, - "context_attention": {"out"}, - "destindex_copy_kv": {"out"}, - "token_attention": {"out"}, - "token_softmax_reducev": {"out"}, - "random": {"input"}, - "uniform": {"input"}, - "normal_": {"input"}, - "bernoulli": {"input"}, # compared in manual_test - } +ops_with_states = { + "batch_norm": {"running_mean", "running_var"}, + "sgd": {"buf", "param"}, + "fill_": {"input"}, + "zero_": {"input"}, + "embedding": {"weight"}, + "adam": {"param", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"}, + "adamw": {"param", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"}, + "adadelta": {"param", "square_avg", "acc_delta"}, + "rmsprop": {"param", "square_avg", "grad_avg", "momentum_buffer"}, + "copy_": {"input"}, + "cast_dtype": {"out"}, + "batch_norm_gather_stats_with_counts": {"running_mean", "running_var"}, + "clip_grad_norm_": {"tensors"}, + "apply_penalty": {"logits"}, + "context_attention": {"out"}, + "apply_penalty_v2": {"logits"}, + "rotary_emb_v2": {"query", "key"}, + "destindex_copy_kv": {"out"}, + "token_attention": {"out"}, + "token_softmax_reducev": {"out"}, + "random": {"input"}, + "uniform": {"input"}, + "normal_": 
{"input"}, + "bernoulli": {"input"}, # compared in manual_test +} diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index ceb71e244b..9c465653db 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -237,7 +237,9 @@ ascend_npu: - diopiTokenAttentionInference - diopiTokenSoftmaxReduceVInference - diopiApplyPenalty +- diopiApplyPenaltyV2 - diopiContextAttentionInference +- diopiPromptFlashAttention - diopiNativeMemoryFormatCast - diopiGetNativeMemoryFormat - diopiTensorDestructionHook @@ -272,3 +274,6 @@ ascend_npu: - diopiAttentionBackward - diopiAttentionVarLen - diopiAttentionVarLenBackward +- diopiPagedAttention +- diopiRotaryEmbeddingV2 +- diopiMatmulAllReduce diff --git a/impl/ascend_npu/diopi_impl/functions_ext/apply_penalty.cpp b/impl/ascend_npu/diopi_impl/functions_ext/apply_penalty.cpp index d1ec6177ce..b012c0752c 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/apply_penalty.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/apply_penalty.cpp @@ -31,4 +31,26 @@ diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHandle_t log } END_CALL_ACL_OP(); } + +diopiError_t diopiApplyPenaltyV2(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presencePenalty, + diopiConstTensorHandle_t frequencyPenalty, diopiConstTensorHandle_t repetitionPenalty, diopiConstTensorHandle_t pTokenIds, + diopiConstTensorHandle_t pTokenCounts) { + BEGIN_CALL_ACL_OP(logits, presencePenalty, frequencyPenalty, repetitionPenalty, pTokenIds, pTokenCounts); + logitsAt = impl::aten::viewStorage(logitsAt, {logitsAt.numel()}); + at::Tensor curLogits = op_api::index_select(logitsAt, 0, pTokenIdsAt); + at::Tensor repoLogits = at_npu::native::OpPreparation::apply_tensor_without_format(curLogits); + at::Tensor zero = at_npu::native::OpPreparation::apply_tensor_without_format(curLogits); + op_api::zero_(zero); + at::Tensor cand = at_npu::native::OpPreparation::apply_tensor_without_format(curLogits); + op_api::gt_out(curLogits, zero, cand); + op_api::where_out(cand, curLogits / repetitionPenaltyAt, curLogits * repetitionPenaltyAt, repoLogits); + repoLogits = repoLogits - pTokenCountsAt * frequencyPenaltyAt - presencePenaltyAt; + std::vector shape(pTokenIdsAt.dim() + 1, 1); + for (int64_t i = 0; i < pTokenIdsAt.dim(); i++) { + shape[i] = pTokenIdsAt.size(i); + } + pTokenIdsAt = impl::aten::viewStorage(pTokenIdsAt, shape); + EXEC_NPU_CMD(aclnnScatterNd, logitsAt, pTokenIdsAt, repoLogits, logitsAt); + END_CALL_ACL_OP(); +} } // namespace OP_IMPL_NS diff --git a/impl/ascend_npu/diopi_impl/functions_ext/context_attention_inference.cpp b/impl/ascend_npu/diopi_impl/functions_ext/context_attention_inference.cpp index 17a462f657..25b133d952 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/context_attention_inference.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/context_attention_inference.cpp @@ -47,4 +47,37 @@ diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiTenso END_CALL_ACL_OP(); } +diopiError_t diopiPromptFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t query, diopiConstTensorHandle_t key, + diopiConstTensorHandle_t value, diopiConstTensorHandle_t attenMask, diopiSize_t actualSeqLengths, int64_t maxInputLen, + int64_t numHeads, int64_t numKeyValueHeads, int64_t dim) { + BEGIN_CALL_ACL_OP(out, query, key, value, attenMask); + at::IntArrayRef actSeqLen(actualSeqLengths.data, actualSeqLengths.len); + if (queryAt.dim() == 2) { + queryAt = 
impl::aten::viewStorage(queryAt, {actualSeqLengths.len, maxInputLen, queryAt.size(1)}); + outAt = impl::aten::viewStorage(outAt, {actualSeqLengths.len, maxInputLen, outAt.size(1)}); + keyAt = impl::aten::viewStorage(keyAt, {actualSeqLengths.len, maxInputLen, keyAt.size(1)}); + valueAt = impl::aten::viewStorage(valueAt, {actualSeqLengths.len, maxInputLen, valueAt.size(1)}); + } + double scaleValue = 1 / std::sqrt(dim); + int64_t preTokens = 2147473647; + int64_t nextTokens = 0; + at::Tensor paddingMask; + EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnPromptFlashAttention, + queryAt, + keyAt, + valueAt, + paddingMask, + attenMaskAt, + actSeqLen, + numHeads, + scaleValue, + preTokens, + nextTokens, + "BSH", + numKeyValueHeads, + outAt); + + END_CALL_ACL_OP(); +} + } // namespace OP_IMPL_NS diff --git a/impl/ascend_npu/diopi_impl/functions_ext/destindex_copy_kv.cpp b/impl/ascend_npu/diopi_impl/functions_ext/destindex_copy_kv.cpp index 71ae6a9090..d3a368dd89 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/destindex_copy_kv.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/destindex_copy_kv.cpp @@ -12,7 +12,16 @@ namespace OP_IMPL_NS { diopiError_t diopiDestIndexCopyKV(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t k, diopiConstTensorHandle_t destLoc) { BEGIN_CALL_ACL_OP(out, k, destLoc); - at::index_put_(outAt, {destLocAt}, kAt, false); + if (destLocAt.sizes().size() != 1) { + set_last_error_string("only support destLoc.rank == 1"); + return diopiNoImplement; + } + std::vector shape(destLocAt.dim() + 1, 1); + for (int64_t i = 0; i < destLocAt.dim(); i++) { + shape[i] = destLocAt.size(i); + } + auto destLocReshapeAt = impl::aten::viewStorage(destLocAt, shape); + EXEC_NPU_CMD(aclnnScatterNd, outAt, destLocReshapeAt, kAt, outAt); END_CALL_ACL_OP(); } } // namespace OP_IMPL_NS diff --git a/impl/ascend_npu/diopi_impl/functions_ext/matmul_all_reduce.cpp b/impl/ascend_npu/diopi_impl/functions_ext/matmul_all_reduce.cpp new file mode 100644 index 0000000000..f6fa4a4561 --- /dev/null +++ b/impl/ascend_npu/diopi_impl/functions_ext/matmul_all_reduce.cpp @@ -0,0 +1,19 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. 
+ */ + +#include "../helper.hpp" +#include "op_plugin/utils/op_api_common.h" + +namespace OP_IMPL_NS { + +diopiError_t diopiMatmulAllReduce(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x1, diopiConstTensorHandle_t x2, + diopiConstTensorHandle_t bias, const char* group, const char* reduceOp, int64_t commTurn, int64_t streamMode) { + BEGIN_CALL_ACL_OP(out, x1, x2, bias); + EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnMatmulAllReduce, x1At, x2At, biasAt, group, reduceOp, commTurn, streamMode, outAt); + END_CALL_ACL_OP(); +} + +} // namespace OP_IMPL_NS diff --git a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp index 4534293d78..a29508819f 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp @@ -9,6 +9,7 @@ #include "../helper.hpp" #include "op_plugin/AclOpsInterface.h" #include "op_plugin/OpApiInterface.h" +#include "op_plugin/utils/op_api_common.h" namespace OP_IMPL_NS { @@ -73,4 +74,20 @@ DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso END_CALL_ACL_OP(); } +DIOPI_API diopiError_t diopiRotaryEmbeddingV2(diopiContextHandle_t ctx, diopiTensorHandle_t query, diopiTensorHandle_t key, diopiConstTensorHandle_t cos, + diopiConstTensorHandle_t sin, int64_t dim) { + BEGIN_CALL_ACL_OP(query, key, cos, sin); + int64_t layOut = 1; + if (queryAt.dim() == 2) { + queryAt = impl::aten::viewStorage(queryAt, {1, queryAt.size(0), queryAt.size(1) / dim, dim}); + } + if (keyAt.dim() == 2) { + keyAt = impl::aten::viewStorage(keyAt, {1, keyAt.size(0), keyAt.size(1) / dim, dim}); + } + cosAt = viewAs4D(cosAt); + sinAt = viewAs4D(sinAt); + EXEC_NPU_CMD(aclnnApplyRotaryPosEmb, queryAt, keyAt, cosAt, sinAt, layOut); + END_CALL_ACL_OP(); +} + } // namespace OP_IMPL_NS diff --git a/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp b/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp index c694ef0ff8..903f21a023 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp @@ -38,4 +38,52 @@ diopiError_t diopiTokenAttentionInference(diopiContextHandle_t ctx, diopiTensorH END_CALL_ACL_OP(); } +diopiError_t diopiPagedAttention(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, + diopiConstTensorHandle_t v, diopiSize_t actualSeqLengths, int64_t numHeads, int64_t numKeyValueHeads, int64_t dim, + diopiConstTensorHandle_t blockTable, int64_t blockSize) { + BEGIN_CALL_ACL_OP(out, q, k, v, blockTable); + at::IntArrayRef actSeqLen(actualSeqLengths.data, actualSeqLengths.len); + TORCH_CHECK(actualSeqLengths.len == qAt.size(0), "The size of the first dimension of q must be equal to the length of actualSeqLengths!"); + TORCH_CHECK(actualSeqLengths.len == outAt.size(0), "The size of the first dimension of out must be equal to the length of actualSeqLengths!"); + if (qAt.dim() == 2) { + qAt = impl::aten::viewStorage(qAt, {qAt.size(0), (int64_t)1, qAt.size(1)}); + outAt = impl::aten::viewStorage(outAt, {outAt.size(0), (int64_t)1, outAt.size(1)}); + kAt = impl::aten::viewStorage(kAt, {kAt.size(0), (int64_t)1, kAt.size(1)}); + vAt = impl::aten::viewStorage(vAt, {vAt.size(0), (int64_t)1, vAt.size(1)}); + } + if (qAt.dim() == 3) { + TORCH_CHECK(1 == qAt.size(1), "The size of the second dimension of q must be 1!"); + 
TORCH_CHECK(1 == outAt.size(1), "The size of the second dimension of out must be 1!"); + } + double scaleValue = 1 / std::sqrt(dim); + at::TensorList keyTensors = kAt; + at::TensorList valueTensors = vAt; + int64_t innerPrecise = 1; + at::Tensor paddingMask, attenMask, dequantScale1, quantScale1, dequantScale2, quantScale2, quantOffset2, antiquantScale, antiquantOffset, kvPaddingSize; + EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnIncreFlashAttentionV4, + qAt, + keyTensors, + valueTensors, + paddingMask, + attenMask, + actSeqLen, + dequantScale1, + quantScale1, + dequantScale2, + quantScale2, + quantOffset2, + antiquantScale, + antiquantOffset, + blockTableAt, + kvPaddingSize, + numHeads, + scaleValue, + "BSH", + numKeyValueHeads, + blockSize, + innerPrecise, + outAt); + END_CALL_ACL_OP() +} + } // namespace OP_IMPL_NS diff --git a/proto/include/diopi/functions_ext.h b/proto/include/diopi/functions_ext.h index a5889f63d3..87bed2f35a 100644 --- a/proto/include/diopi/functions_ext.h +++ b/proto/include/diopi/functions_ext.h @@ -561,6 +561,26 @@ DIOPI_API diopiError_t diopiApplyPenalty(diopiContextHandle_t ctx, diopiTensorHa diopiConstTensorHandle_t frequency_penalty, diopiConstTensorHandle_t p_token_ids, diopiConstTensorHandle_t p_token_counts, diopiConstTensorHandle_t p_cumsum_seq_len, int p_max_len_in_batch); +/** + * @brief This function applies a penalty to the given logits based on the presence and frequency of certain tokens in the input sequence to suppress + * generating tokens repeatedly. + * For each token,the final logit = logit - corresponding_presence_penalty * token_counts - corresponding_presence_penalty. + * @param[in] ctx The diopi context. + * @param[inout] logits Tensor representing the logits. Shape: [batch_size, voc_len]. It contains the predicted scores for each token in the input sequences. + * It will be penalized by frequency_penalty and presence_penalty. + * @param[in] presence_penalty Tensor representing the presence penalty for each batch. Shape: [batch_size,]. It contains the penalty values to be subtracted + * from the logits. + * @param[in] frequency_penalty Tensor representing the frequency penalty for each batch. Shape: [batch_size,]. It contains the penalty values to be subtracted + * from the logits. + * @param[in] repetition_penalty Tensor representing the repetition penalty for each batch. Shape: [batch_size,]. It contains the penalty values to be + * subtracted from the logits. + * @param[in] p_token_ids Tensor representing the token_ids for generated tokens. Shape:[generated_tokens_num]. + * @param[in] p_token_counts Tensor representing the count of each token for generated tokens. Shape:[generated_tokens_num]. + */ +DIOPI_API diopiError_t diopiApplyPenaltyV2(diopiContextHandle_t ctx, diopiTensorHandle_t logits, diopiConstTensorHandle_t presence_penalty, + diopiConstTensorHandle_t frequency_penalty, diopiConstTensorHandle_t repetition_penalty, + diopiConstTensorHandle_t p_token_ids, diopiConstTensorHandle_t p_token_counts); + /** * @brief Copies the elements from k tensor into out tensor according to dest_loc tensor. It can be expressed in detail as: out[dest_loc] = k. 
During * model initialization, the KV cache is pre-allocated based on the user-set max_total_token_num and a Token Table is created to record the actual storage @@ -611,6 +631,23 @@ DIOPI_API diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ct diopiConstTensorHandle_t v, diopiConstTensorHandle_t b_loc, diopiConstTensorHandle_t b_start_loc, diopiConstTensorHandle_t b_seq_len, int max_input_len, int other_kv_index); +/** + * @brief The implementation of pagedAttention, for more details please refer to https://blog.vllm.ai/2023/06/20/vllm.html + * @param[in] ctx diopi context. + * @param[in] out The output tensor of page attention operation. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] q Tensor representing the query matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim]. + * @param[in] k Tensor representing the key matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] v Tensor representing the value matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] actual_seq_lengths Tensor representing the sequence length in each batch. shape = [batch_size] + * @param[in] num_heads head number of q and out. + * @param[in] num_kv_heads head number of key and value. + * @param[in] dim dimension of the transformer. + * @param[in] block_table Tensor representing the used blocks in each batch. shape = [batch_size, max_length_of_block_list] + * @param[in] block_size Size of eatch block unit. + */ +DIOPI_API diopiError_t diopiPagedAttention(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, + diopiConstTensorHandle_t v, diopiSize_t actual_seq_lengths, int64_t num_heads, int64_t num_kv_heads, int64_t dim, + diopiConstTensorHandle_t block_table, int64_t block_size); /** * @brief The no pad implementation of * \text{context_attention_out}(\mathrm{q},\mathrm{k},\mathrm{v})=\text{softmax}(\frac{\mathrm{qk}^\mathrm{T}}{\sqrt{\mathrm{d_k}}})\mathrm{v}. For details, @@ -630,6 +667,55 @@ DIOPI_API diopiError_t diopiContextAttentionInference(diopiContextHandle_t ctx, diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiConstTensorHandle_t b_start_loc, diopiConstTensorHandle_t b_seq_len, int max_input_len); +/** + * @brief The no pad implementation of apply rotary embedding operation. For details, please refer to the official implementation using the triton kernel: + * https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rotary_emb.py + * @param[in] ctx The diopi context. + * @param[out] out The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64]. + * @param[in] query The query tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64]. + * @param[in] key The key tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64]. + * @param[in] cos The cosine values. type = [bfloat16, float16, float32, float64]. + * @param[in] sin The sine values. type = [bfloat16, float16, float32, float64]. + * @param[in] dim dimension of the transformer. + */ +DIOPI_API diopiError_t diopiRotaryEmbeddingV2(diopiContextHandle_t ctx, diopiTensorHandle_t query, diopiTensorHandle_t key, diopiConstTensorHandle_t cos, + diopiConstTensorHandle_t sin, int64_t dim); + +/** + * @brief The fused operation of Matmul and AllReduce. + * @param[in] ctx The diopi context. 
+ * @param[out] out The output tensor of Matmul and AllReduce. + * @param[in] x1 The x1 tensor of matmul. type = [bfloat16, float16, float32, float64]. + * @param[in] x2 The x2 tensor of matmul. type = [bfloat16, float16, float32, float64]. + * @param[in] bias The bias tensor of matmul. type = [bfloat16, float16, float32, float64]. + * @param[in] group The group string of AllReduce. + * @param[in] reduceOp The reduce op string of AllReduce. + * @param[in] commTurn communication turn. + * @param[in] streamMode The stream mode for communication. + */ +DIOPI_API diopiError_t diopiMatmulAllReduce(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x1, diopiConstTensorHandle_t x2, + diopiConstTensorHandle_t bias, const char* group, const char* reduce_op, int64_t comm_turn, int64_t stream_mode); + +/** + * @brief The no pad implementation of + * \text{context_attention_out}(\mathrm{q},\mathrm{k},\mathrm{v})=\text{softmax}(\frac{\mathrm{qk}^\mathrm{T}}{\sqrt{\mathrm{d_k}}})\mathrm{v}. For details, + * please refer to the official implementation using the triton kernel: + * https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py. + * @param[in] ctx diopi context. + * @param[in] out The output tensor of prompt flash attention operation. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] query Tensor representing the query matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim]. + * @param[in] key Tensor representing the key matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] value Tensor representing the value matrix in the attention mechanism. shape = [sum_batch_seq_len, head_num * head_dim] + * @param[in] atten_mask Tensor representing the mask matrix in the attention mechanism. + * @param[in] actual_seq_lengths Tensor representing the sequence length in each batch. shape = [batch_size] + * @param[in] max_input_len The maximum length of all batch corresponding sequences. + * @param[in] num_heads head number of query and out. + * @param[in] num_kv_heads head number of key and value. + * @param[in] dim dimension of the transformer. + */ +DIOPI_API diopiError_t diopiPromptFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t query, + diopiConstTensorHandle_t key, diopiConstTensorHandle_t value, diopiConstTensorHandle_t atten_mask, + diopiSize_t actual_seq_lengths, int64_t max_input_len, int64_t num_heads, int64_t num_kv_heads, int64_t dim); // ============================================lightllm end======================================== #if defined(__cplusplus)
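
For reference, the block addressing assumed by the paged_attention test case above (blockTable [[0, 1]], blockSize 128, actualSeqLengths [150]) is not spelled out in the patch itself; the following is a minimal PyTorch sketch of the usual vLLM-style mapping that is consistent with those values. It is illustrative only: gather_kv_for_batch and its argument names are not part of the DIOPI API.

import torch

def gather_kv_for_batch(kv_cache, block_table_row, seq_len, block_size):
    # kv_cache:        [num_blocks * block_size, num_kv_heads * head_dim]
    # block_table_row: [max_blocks_per_seq], int32, one row of the per-batch block table
    # Token `pos` of this sequence is assumed to live at cache row
    # block_table_row[pos // block_size] * block_size + pos % block_size.
    pos = torch.arange(seq_len)
    rows = block_table_row[pos // block_size].to(torch.int64) * block_size + pos % block_size
    return kv_cache[rows]  # [seq_len, num_kv_heads * head_dim]

# With the values from the test config above: rows 0..127 are read from block 0
# and rows 128..149 from block 1 of the 1026-row key/value cache.
kv_cache = torch.randn(1026, 4096, dtype=torch.float16)
block_table_row = torch.tensor([0, 1], dtype=torch.int32)
keys = gather_kv_for_batch(kv_cache, block_table_row, seq_len=150, block_size=128)
assert keys.shape == (150, 4096)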