Skip to content

Commit

Permalink
Fix two issues with rmsnorm.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 21, 2023
1 parent ff47b59 commit d6045cf
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 5 deletions.
2 changes: 1 addition & 1 deletion lib/nnc/ccv_cnnp_model_addons.c
Original file line number Diff line number Diff line change
Expand Up @@ -2107,7 +2107,7 @@ static void _ccv_cnnp_rmsnorm_build(ccv_cnnp_model_t* const super, ccv_nnc_symbo
// Both scale and bias are shared between if this model is reused.
if (!self->scale.graph)
self->scale = ccv_nnc_tensor_symbol_new(graph, scale_params, "scale");
const ccv_nnc_cmd_t rmsnorm = ccv_nnc_cmd(CCV_NNC_LAYER_NORM_FORWARD, 0, self->params, 0);
const ccv_nnc_cmd_t rmsnorm = ccv_nnc_cmd(CCV_NNC_RMSNORM_FORWARD, 0, self->params, 0);
ccv_nnc_tensor_param_t output_params[2];
ccv_nnc_hint_tensor_auto(rmsnorm, (ccv_nnc_tensor_param_t []){
params,
Expand Down
4 changes: 0 additions & 4 deletions lib/nnc/cmd/norm/ccv_nnc_norm.c
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,6 @@ static int _ccv_nnc_rmsnorm_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const in
// 2 outputs (y, saved_inv_std)
if (input_bitmasks[0] == 3u && output_bitmasks[0] == 3u)
return 1;
// 2 inputs (x, gamma)
// 1 output (y)
if (input_bitmasks[0] == 3u && output_bitmasks[0] == 1u)
return 1;
return 0;
}

Expand Down

0 comments on commit d6045cf

Please sign in to comment.