【LLM Inference-0】Add Split MoE Op && Add Group MoE #69687

Open · wants to merge 20 commits into base: develop
68 changes: 68 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -5931,6 +5931,7 @@ void FusedMoeInferMeta(const MetaTensor& X,
const MetaTensor& ffn2_bias,
const std::string& quant_method,
const int moe_topk,
+                       const bool group_moe,
const bool norm_topk_prob,
MetaTensor* out) {
out->set_dims(X.dims());
@@ -6019,6 +6020,73 @@ void MultiheadMatmulInferMeta(const MetaTensor& input,
out->share_lod(input);
}

+void moe_dispatchInferMeta(const MetaTensor& X,
+                           const MetaTensor& gating_output,
+                           const int moe_topk,
+                           const bool group_moe,
+                           MetaTensor* permute_input,
+                           MetaTensor* token_nums_per_expert,
+                           MetaTensor* permute_indices_per_token,
+                           MetaTensor* expert_scales_float,
+                           MetaTensor* top_k_indices) {
+  int token_rows = 0;
+  auto input_dims = X.dims();
+  if (input_dims.size() == 3) {
+    token_rows = input_dims[0] * input_dims[1];
+  } else {
+    token_rows = input_dims[0];
+  }
+  const int num_rows = token_rows;
+  const int hidden_size = input_dims[input_dims.size() - 1];

+  permute_input->set_dims({moe_topk * num_rows, hidden_size});
+  permute_input->set_dtype(X.dtype());
+  permute_input->set_layout(X.layout());

+  permute_indices_per_token->set_dims({moe_topk, num_rows});
+  permute_indices_per_token->set_dtype(DataType::INT32);
+  permute_indices_per_token->set_layout(X.layout());

+  expert_scales_float->set_dims({num_rows, moe_topk});
+  expert_scales_float->set_dtype(DataType::FLOAT32);
+  expert_scales_float->set_layout(X.layout());

+  top_k_indices->set_dims({num_rows, moe_topk});
+  top_k_indices->set_dtype(DataType::INT32);
+  top_k_indices->set_layout(X.layout());
+}
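One gap worth flagging in `moe_dispatchInferMeta`: `token_nums_per_expert` is declared as an output but never given dims, dtype, or layout. A plausible completion, which is my assumption rather than code from this PR (it reads the expert count off the gating logits, shaped [num_rows, num_experts], and assumes one int64 count per expert), would be:

```cpp
// Hypothetical completion (not in the PR as shown): one routed-token count
// per expert, with the expert count taken from the gating logits' last dim.
const int64_t num_experts = gating_output.dims()[gating_output.dims().size() - 1];
token_nums_per_expert->set_dims({num_experts});
token_nums_per_expert->set_dtype(DataType::INT64);
token_nums_per_expert->set_layout(X.layout());
```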

+void moe_ffnInferMeta(const MetaTensor& permute_input,
+                      const MetaTensor& token_nums_per_expert,
+                      const MetaTensor& ffn1_weight,
+                      const MetaTensor& ffn2_weight,
+                      const MetaTensor& ffn1_bias,
+                      const MetaTensor& ffn1_scale,
+                      const MetaTensor& ffn2_scale,
+                      const std::string& quant_method,
+                      MetaTensor* ffn_out) {
+  ffn_out->set_dims(permute_input.dims());
+  ffn_out->share_lod(permute_input);
+  ffn_out->set_dtype(permute_input.dtype());
+  ffn_out->set_layout(permute_input.layout());
+}
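At the meta level the FFN stage is shape-preserving: each permuted row goes hidden_size → inter_size → hidden_size through the two expert GEMMs, so `ffn_out` can simply mirror `permute_input`. A shape walk-through under the weight layout implied by `MoeHelper` below (`ffn1_dims[1]` is read as the hidden size, so `ffn1_weight` is [E, H, I]; by symmetry I assume `ffn2_weight` is [E, I, H]):

```cpp
// Illustrative per-expert shape algebra (not code from the PR):
//   fc1: [rows_e, H] x [H, I] -> [rows_e, I]
//   fc2: [rows_e, I] x [I, H] -> [rows_e, H]
// Across experts the rows_e sum back to moe_topk * num_rows, and the trailing
// width returns to H, hence ffn_out simply copies permute_input's dims.
```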

+void moe_reduceInferMeta(const MetaTensor& ffn_out,
+                         const MetaTensor& expert_scales_float,
+                         const MetaTensor& permute_indices_per_token,
+                         const MetaTensor& top_k_indices,
+                         const MetaTensor& ffn2_bias,
+                         const bool norm_topk_prob,
+                         MetaTensor* output) {
+  auto ffn_out_dims = ffn_out.dims();
+  const int top_k = top_k_indices.dims()[1];
+  const int num_rows = ffn_out_dims[0] / top_k;
+  const int hidden_size = ffn_out_dims[1];
+  output->set_dims({num_rows, hidden_size});
+  output->set_dtype(ffn_out.dtype());
+  output->set_layout(ffn_out.layout());
+}
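Taken together, the three split ops round-trip the token count: `moe_dispatch` expands `num_rows` rows to `moe_topk * num_rows`, `moe_ffn` keeps that shape, and `moe_reduce` divides it back out using `top_k` read from `top_k_indices`. A concrete shape trace with made-up sizes:

```cpp
// Hypothetical sizes: num_rows = 8, hidden_size = 512, moe_topk = 2.
//   moe_dispatch: X [8, 512]              -> permute_input [16, 512],
//                 top_k_indices [8, 2], expert_scales_float [8, 2]
//   moe_ffn:      permute_input [16, 512] -> ffn_out [16, 512]
//   moe_reduce:   ffn_out [16, 512], top_k = 2
//                 -> output [16 / 2, 512] = [8, 512]
```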

void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& bias,
29 changes: 29 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -1176,8 +1176,37 @@ void FusedMoeInferMeta(const MetaTensor& X,
const std::string& quant_method,
const int moe_topk,
+                       const bool group_moe,
                       const bool norm_topk_prob,
MetaTensor* out);

+void moe_dispatchInferMeta(const MetaTensor& X,
+                           const MetaTensor& gating_output,
+                           const int moe_topk,
+                           const bool group_moe,
+                           MetaTensor* permute_input,
+                           MetaTensor* token_nums_per_expert,
+                           MetaTensor* permute_indices_per_token,
+                           MetaTensor* expert_scales_float,
+                           MetaTensor* top_k_indices);

+void moe_ffnInferMeta(const MetaTensor& permute_input,
+                      const MetaTensor& token_nums_per_expert,
+                      const MetaTensor& ffn1_weight,
+                      const MetaTensor& ffn2_weight,
+                      const MetaTensor& ffn1_bias,
+                      const MetaTensor& ffn1_scale,
+                      const MetaTensor& ffn2_scale,
+                      const std::string& quant_method,
+                      MetaTensor* ffn_out);

+void moe_reduceInferMeta(const MetaTensor& ffn_out,
+                         const MetaTensor& expert_scales_float,
+                         const MetaTensor& permute_indices_per_token,
+                         const MetaTensor& top_k_indices,
+                         const MetaTensor& ffn2_bias,
+                         const bool norm_topk_prob,
+                         MetaTensor* output);

void FusedMultiHeadAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
@@ -129,7 +129,7 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64,
-          CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
+          // CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
};
if (is_moe) {
quant_B_configs_sm80.push_back(
@@ -251,10 +251,10 @@ static CutlassGemmConfig estimate_best_config_from_occupancies(
[](const CutlassGemmConfig& gemm_config) {
return gemm_config.tile_config ==
CutlassTileConfig::
-                         CtaShape128x256x64_WarpShape64x64x64;
+                         CtaShape128x128x64_WarpShape128x32x64;
}) != candidate_configs.end()) {
best_config = CutlassGemmConfig{
-        CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
+        CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
SplitKStyle::NO_SPLIT_K,
1,
5};
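This hunk retires `CtaShape128x256x64_WarpShape64x64x64` as the heuristic's fallback in favor of `CtaShape128x128x64_WarpShape128x32x64`, and comments the big tile out of the sm80 candidate list above; the hunk itself doesn't say why. For orientation, the warp-count arithmetic for the two shapes, as a standalone sketch of my own:

```cpp
#include <cstdio>

// Warps per CTA = product over {M, N, K} of (cta_dim / warp_dim).
constexpr int warps(int cm, int cn, int ck, int wm, int wn, int wk) {
  return (cm / wm) * (cn / wn) * (ck / wk);
}

int main() {
  // Disabled candidate: larger tile, 8 warps per CTA, more smem per stage.
  std::printf("128x256x64 / 64x64x64  -> %d warps\n",
              warps(128, 256, 64, 64, 64, 64));  // 8
  // New fallback: half the N extent, 4 warps per CTA.
  std::printf("128x128x64 / 128x32x64 -> %d warps\n",
              warps(128, 128, 64, 128, 32, 64));  // 4
  return 0;
}
```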
2 changes: 2 additions & 0 deletions paddle/phi/kernels/fusion/cutlass/fused_moe_kernel.cu
@@ -54,6 +54,7 @@ void FusedMoeKernel(const Context& ctx,
const paddle::optional<DenseTensor>& ffn2_bias,
const std::string& quant_method,
const int moe_topk,
+                    const bool group_moe,
const bool norm_topk_prob,
DenseTensor* out) {
out->Resize(X.dims());
@@ -84,6 +85,7 @@
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
+        group_moe,
norm_topk_prob,
"ffn",
out);
79 changes: 48 additions & 31 deletions paddle/phi/kernels/fusion/cutlass/moe/fused_moe_helper.h
@@ -75,15 +75,15 @@ class MoeHelper {

// -------- getWorkspaceSize -------- //
template <typename KeyT>
-  size_t getWorkspaceSize(const int num_rows,
-                          const int hidden_size,
-                          const int inter_size,
-                          const int num_experts,
-                          const int k) {
-    const int buf_size = AlignTo16(k * num_rows * hidden_size);
-    const int interbuf_size = AlignTo16(k * num_rows * inter_size);
-    const int padded_experts = AlignTo16(num_experts);
-    const int num_moe_inputs = AlignTo16(k * num_rows);
+  size_t getWorkspaceSize(const int64_t num_rows,
+                          const int64_t hidden_size,
+                          const int64_t inter_size,
+                          const int64_t num_experts,
+                          const int64_t k) {
+    const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
+    const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
+    const size_t padded_experts = AlignTo16(num_experts);
+    const size_t num_moe_inputs = AlignTo16(k * num_rows);
// softmax output, permuted_rows and permuted_experts have moved to outside
// of moe kernel, allocate them in Encoder or Decoder before invoking
// FfnLayer forward.
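The int to int64_t/size_t widening in this file is not cosmetic: the workspace sizing multiplies k * num_rows * inter_size, which blows past INT32_MAX at realistic serving sizes. A minimal demonstration with assumed numbers:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed sizes: k = 8, num_rows = 65536 tokens, inter_size = 14336.
  const int64_t k = 8, num_rows = 65536, inter_size = 14336;
  const int64_t interbuf = k * num_rows * inter_size;  // 7,516,192,768
  std::printf("interbuf elements: %lld (INT32_MAX = %d)\n",
              static_cast<long long>(interbuf), INT32_MAX);
  // Computed in 32-bit ints, this product would wrap and corrupt the
  // workspace size that getWorkspaceSize reports.
  return 0;
}
```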
@@ -94,14 +94,14 @@ class MoeHelper {
total_ws_bytes +=
padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_

-    const int bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
-    const int sorter_ws_size_bytes =
+    const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
+    const size_t sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(num_rows));
sorter_.update_num_experts(num_experts);

-    int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
+    int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
-      int remaining_bytes =
+      int64_t remaining_bytes =
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
@@ -110,7 +110,7 @@ class MoeHelper {
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
// sorting workspace

-    int num_softmax_outs = 0;
+    int64_t num_softmax_outs = 0;
const bool is_pow_2 =
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
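The condition above is the classic power-of-two bit trick: `num_softmax_outs` only becomes nonzero when the fused top-k softmax path can't be used, which this check ties to non-power-of-two expert counts or more than 256 experts (my reading of the condition; the kernel constraint isn't restated in this hunk). The trick in isolation:

```cpp
#include <cstdint>

// (n & (n - 1)) clears the lowest set bit, so it is zero exactly when n has a
// single set bit, i.e. n is a power of two; "n != 0" guards the zero case.
constexpr bool is_pow2(int64_t n) { return n != 0 && (n & (n - 1)) == 0; }

static_assert(is_pow2(64) && is_pow2(256), "fused path sizes");
static_assert(!is_pow2(96) && !is_pow2(0), "fallback sizes");
```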
@@ -132,6 +132,7 @@ class MoeHelper {
const DenseTensor *ffn2_bias,
const DenseTensor *moe_token_type_ids,
const int moe_topk,
+                  const bool group_moe,
const bool norm_topk_prob,
const std::string moe_type,
DenseTensor *output) {
@@ -145,16 +146,16 @@ class MoeHelper {

auto input_dims = X->dims();
auto ffn1_dims = ffn1_weight->dims();
-    int token_num = 0;
+    int64_t token_num = 0;
if (input_dims.size() == 3) {
token_num = input_dims[0] * input_dims[1];
} else {
token_num = input_dims[0];
}
-    const int num_rows = token_num;
+    const int64_t num_rows = token_num;

-    const int hidden_size = ffn1_dims[1];
-    int inter_dim = 0;
+    const int64_t hidden_size = ffn1_dims[1];
+    int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
} else {
@@ -165,12 +166,17 @@ class MoeHelper {
inter_dim = inter_dim * 2;
}

-    const int inter_size = inter_dim;
-    const int num_experts = ffn1_dims[0];
-    const int k = moe_topk;
+    const int64_t inter_size = inter_dim;
+    const int64_t num_experts = ffn1_dims[0];
+    const int64_t k = moe_topk;

VLOG(4) << "num_rows: " << num_rows << " " << hidden_size << " "
<< inter_size << " " << num_experts << "k " << k;
VLOG(4) << "[MoE Info] "
<< "num_rows: " << num_rows << ", "
<< "hidden_size: " << hidden_size << ", "
<< "inter_size: " << inter_size << ", "
<< "num_experts: " << num_experts << ", "
<< "k: " << k << ", "
<< "group_moe: " << std::boolalpha << group_moe;

DenseTensor gate_tensor = Empty<float>(ctx, {num_rows, num_experts});
DenseTensor X_tensor = Empty<float>(ctx, {num_rows, hidden_size});
@@ -194,10 +200,10 @@ class MoeHelper {
DenseTensor ws_ptr_tensor = Empty<int8_t>(ctx, {bytes});
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();

-    const int buf_size = AlignTo16(k * num_rows * hidden_size);
-    const int interbuf_size = AlignTo16(k * num_rows * inter_size);
-    const int padded_experts = AlignTo16(num_experts);
-    const int num_moe_inputs = AlignTo16(k * num_rows);
+    const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
+    const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
+    const int64_t padded_experts = AlignTo16(num_experts);
+    const int64_t num_moe_inputs = AlignTo16(k * num_rows);

expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
source_rows_ = expert_for_source_row + num_moe_inputs;
@@ -229,10 +235,19 @@ class MoeHelper {
DenseTensor expert_scales_tensor_float = Empty<float>(ctx, {num_rows, k});
float *expert_scales_float = expert_scales_tensor_float.data<float>();

+    DenseTensor softmax_max_prob_tensor;
+    float *softmax_max_prob = nullptr;
+    if (group_moe) {
+      // Keep the tensor at function scope: taking data<float>() from a
+      // DenseTensor local to this block would leave softmax_max_prob
+      // dangling once the block ends.
+      softmax_max_prob_tensor = Empty<float>(ctx, {num_rows, k});
+      softmax_max_prob = softmax_max_prob_tensor.data<float>();
+      funcs::SetConstant<GPUContext, float> zero_float;
+      zero_float(ctx, &softmax_max_prob_tensor, 0.0f);
+    }

DenseTensor fc1_out_tensor = Empty<T>(ctx, {num_rows * k, inter_size});
T *fc1_out = fc1_out_tensor.data<T>();

VLOG(4) << " gemm_method_ :" << gemm_method_;
VLOG(4) << " gemm method is :" << gemm_method_
<< ". group_moe is :" << group_moe;

DenseTensor mixgemm_workspace;
auto gate_compute = GEMMHelper<float>(
@@ -266,13 +281,15 @@ class MoeHelper {
softmax_out_,
expert_for_source_row,
source_rows_,
+      softmax_max_prob,
num_rows,
num_experts,
k,
+      group_moe,
ctx.stream());
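`softmax_max_prob` reaches the gating top-k kernel only when `group_moe` is set. My understanding, an assumption based on the buffer's {num_rows, k} shape and its name rather than anything stated in this diff, is that the experts are treated as k groups, softmax runs per group, and each selected probability is rescaled by its group's max. A toy version of that rescaling:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  // Assumed semantics, toy sizes: k = 2 groups, 2 experts per group, with
  // per-group softmax probabilities already computed.
  float group0[2] = {0.7f, 0.3f};
  float group1[2] = {0.55f, 0.45f};
  // One expert is picked per group; dividing its probability by the group max
  // gives the winner weight 1.0 before any final norm_topk_prob step.
  float max0 = *std::max_element(group0, group0 + 2);
  float max1 = *std::max_element(group1, group1 + 2);
  std::printf("scaled picks: %.2f %.2f\n", group0[0] / max0, group1[0] / max1);
  return 0;
}
```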

-    const int sorter_ws_size_bytes =
-        AlignTo16(sorter_.getWorkspaceSize(k * num_rows));
+    const int64_t sorter_ws_size_bytes =
+        AlignTo16(sorter_.getWorkspaceSize(k * num_rows));

sorter_.run(fc1_result_,
sorter_ws_size_bytes,
@@ -295,7 +312,7 @@ class MoeHelper {
k,
ctx.stream());

-    const int expanded_active_expert_rows = k * num_rows;
+    const int64_t expanded_active_expert_rows = k * num_rows;

compute_total_rows_before_expert<T>(permuted_experts_,
input_activations,