Skip to content

Commit

Permalink
compute position_ids use custom op
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Jan 24, 2025
1 parent d76e357 commit a5a16d4
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 29 deletions.
69 changes: 69 additions & 0 deletions csrc/gpu/get_position_ids.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

// Fills a flattened position_ids buffer from per-batch encoder/decoder
// sequence lengths. For batch entry b the output layout is:
//   seq_lens_encoder[b] prefill positions 0..len-1, followed by — when the
//   decoder is active (seq_lens_decoder[b] > 0) — one position equal to
//   seq_lens_decoder[b] itself (the decoder length is the next position).
//
// Launch: 1-D, one thread per batch entry. A grid-stride loop is used so the
// kernel is correct for any grid/block configuration, including the
// single-block <<<1, bsz>>> launch used by the host wrapper.
//
// Precondition: position_ids has room for
//   sum(seq_lens_encoder) + count(seq_lens_decoder > 0) ints.
__global__ void GetPositionIdsKernel(
    const int* __restrict__ seq_lens_encoder,  // [bsz] encoder length per batch
    const int* __restrict__ seq_lens_decoder,  // [bsz] decoder length per batch
    int* __restrict__ position_ids,            // flattened output
    const int bsz) {                           // batch size
  const int stride = gridDim.x * blockDim.x;
  for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < bsz; b += stride) {
    // Offset of this batch entry in the flattened output: the sum of every
    // preceding entry's contribution (its encoder length, plus one slot when
    // its decoder is active). O(bsz) per thread; bsz is small in practice.
    int offset = 0;
    for (int i = 0; i < b; i++) {
      offset += seq_lens_encoder[i];
      if (seq_lens_decoder[i] > 0) {
        offset += 1;
      }
    }

    const int encoder_len = seq_lens_encoder[b];
    const int decoder_len = seq_lens_decoder[b];

    // Encoder (prefill) positions: 0, 1, ..., encoder_len - 1.
    for (int i = 0; i < encoder_len; i++) {
      position_ids[offset + i] = i;
    }
    offset += encoder_len;

    // Decoder position: the decoder length itself (the position index of the
    // token being generated next).
    if (decoder_len > 0) {
      position_ids[offset] = decoder_len;
    }
  }
}


// Host wrapper for the get_position_ids custom op.
//
// Fills `position_ids` in place (registered as an in-place op below) from the
// per-batch encoder/decoder sequence lengths, on the stream associated with
// `position_ids`. The caller must pre-allocate `position_ids` with
// sum(seq_lens_encoder) + count(seq_lens_decoder > 0) elements.
void GetPositionIds(const paddle::Tensor& seq_lens_encoder,
                    const paddle::Tensor& seq_lens_decoder,
                    const paddle::Tensor& position_ids) {
  const int bsz = seq_lens_encoder.shape()[0];

  // One thread per batch entry in a single block. CUDA caps a block at 1024
  // threads, so fail loudly here instead of silently launching nothing when
  // the batch size exceeds the limit.
  PD_CHECK(bsz > 0 && bsz <= 1024,
           "get_position_ids: batch size must be in (0, 1024] for the "
           "single-block kernel launch");

  GetPositionIdsKernel<<<1, bsz, 0, position_ids.stream()>>>(
      seq_lens_encoder.data<int>(),
      seq_lens_decoder.data<int>(),
      // Safe: position_ids is registered as an in-place output of this op.
      const_cast<int*>(position_ids.data<int>()),
      bsz);
}

// Register the custom op. `position_ids` is both an input and an in-place
// output (`position_ids_out`): the kernel writes into the caller's buffer
// rather than allocating a new tensor.
PD_BUILD_OP(get_position_ids)
    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "position_ids"})
    .Outputs({"position_ids_out"})
    .SetInplaceMap({{"position_ids", "position_ids_out"}})
    .SetKernelFn(PD_KERNEL(GetPositionIds));
1 change: 1 addition & 0 deletions csrc/setup_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def get_gencode_flags():
"./gpu/step.cu",
"./gpu/quant_int8.cu",
"./gpu/dequant_int8.cu",
"./gpu/get_position_ids.cu",
"./gpu/flash_attn_bwd.cc",
"./gpu/tune_cublaslt_gemm.cu",
"./gpu/sample_kernels/top_p_sampling_reject.cu",
Expand Down
37 changes: 8 additions & 29 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,35 +1298,14 @@ def compute_shared_expert(self, tmp_out, i):

def pre_process(self, **kwargs):
if self.config.mla_config.use_mla():
seq_lens_encoder = kwargs.get("seq_lens_encoder", None).cast("int64")
seq_lens_decoder = kwargs.get("seq_lens_decoder", None).cast("int64")
bsz = seq_lens_encoder.shape[0]

# 处理 Encoder 部分
max_len_encoder = seq_lens_encoder.max().item()
encoder_ids = paddle.arange(max_len_encoder).tile([bsz, 1]) # 每个批次生成完整索引
encoder_mask = paddle.arange(max_len_encoder).unsqueeze(0) < seq_lens_encoder # 根据 encoder 长度生成掩码
encoder_ids = paddle.masked_select(encoder_ids, encoder_mask) # 筛选有效的 Encoder 索引

# 生成批次索引用于保持顺序
encoder_batch_indices = paddle.repeat_interleave(
paddle.arange(bsz), seq_lens_encoder.squeeze(-1)
) # 每个样本的索引重复对应的长度

# 处理 Decoder 部分
decoder_mask = seq_lens_decoder > 0 # 筛选非零 decoder 长度
decoder_ids = paddle.masked_select(seq_lens_decoder, decoder_mask) # 提取非零 decoder 索引
decoder_batch_indices = paddle.masked_select(paddle.arange(bsz), decoder_mask.squeeze(-1)) # 提取有效的批次索引

# 合并 Encoder 和 Decoder
all_ids = paddle.concat([encoder_ids, decoder_ids])
all_batch_indices = paddle.concat([encoder_batch_indices, decoder_batch_indices])

# 根据批次索引排序,保证批次顺序
sorted_indices = paddle.argsort(all_batch_indices)
position_ids = paddle.gather(all_ids, sorted_indices)

self.position_ids = position_ids
seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
seq_lens_decoder = kwargs.get("seq_lens_decoder", None)
# Total flattened length: one position per encoder token, plus one extra slot
# for each batch entry whose decoder is active (seq_lens_decoder > 0).
position_ids_shape = paddle.sum(seq_lens_encoder) + paddle.sum(seq_lens_decoder > 0)
self.position_ids = paddle.zeros(shape=position_ids_shape, dtype=seq_lens_encoder.dtype)

# Custom CUDA op (csrc/gpu/get_position_ids.cu); fills self.position_ids
# in place, replacing the previous pure-Paddle masked-select implementation.
from paddlenlp_ops import get_position_ids

get_position_ids(seq_lens_encoder, seq_lens_decoder, self.position_ids)

def post_process(self, **kwargs):
time_step = kwargs.get("time_step", None)
Expand Down

0 comments on commit a5a16d4

Please sign in to comment.