From 24cad9dbd0f68459f8dae29271b43519a0b79684 Mon Sep 17 00:00:00 2001
From: skykongkong8
Date: Tue, 6 Feb 2024 10:03:29 +0900
Subject: [PATCH] Mixed precision for bn layer

- According to recent work on mixed-precision training, the statistics in the
  bn layer should be computed with fp32 values
- input(fp16) <-> BN_layer(fp16 weights, but computed in fp32 and re-cast to
  fp16) <-> output(fp16)
- Although this requires a bulky code block for now, we can revisit it for
  cleaner code once TensorV2 becomes official (see the notes appended after
  the patch for a demo of the precision issue and a sketch of a possible
  cleanup)

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8
---
 nntrainer/layers/bn_layer.cpp | 300 ++++++++++++++++++++++++++++------
 1 file changed, 251 insertions(+), 49 deletions(-)

diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index 1723ac677f..e3c179d1f0 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -174,32 +174,125 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
   /** use hidden_ as temporary tensor before setting the result in hidden */
   Tensor t_full = hidden_;
   Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
 
+  if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    TensorDim mu_dim = mu.getDim();
+    mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor mu32(mu_dim, true);
+    mu32.copyData(mu);
+
+    TensorDim var_dim = var.getDim();
+    var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor var32(var_dim, true);
+    var32.copyData(var);
+
+    TensorDim gamma_dim = gamma.getDim();
+    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor gamma32(gamma_dim, true);
+    gamma32.copyData(gamma);
+
+    TensorDim beta_dim = beta.getDim();
+    beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor beta32(beta_dim, true);
+    beta32.copyData(beta);
+
+    TensorDim input_dim = input_.getDim();
+    input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor input_32(input_dim, true);
+    input_32.copyData(input_);
+
+    TensorDim hidden_dim = hidden_.getDim();
+    hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor hidden_32(hidden_dim, true);
+    hidden_32.copyData(hidden_);
+    Tensor t_full32 = hidden_32;
+
+    TensorDim deviation_dim = deviation.getDim();
+    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor deviation32(deviation_dim, true);
+    deviation32.copyData(deviation);
+
+    TensorDim dim_invstd = invstd.getDim();
+    dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor invstd32(dim_invstd, true);
+    invstd32.copyData(invstd);
+
+    TensorDim t_reduced_dim = t_reduced.getDim();
+    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor t_reduced32(t_reduced_dim, true);
+    t_reduced32.copyData(t_reduced);
+
+    TensorDim cvar_dim = cvar.getDim();
+    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor cvar32(cvar_dim, true);
+    cvar32.copyData(cvar);
+
+    if (training) {
+      input_32.average(axes_to_reduce, t_reduced32);
+      input_32.subtract(t_reduced32, deviation32);
+
+      mu32.multiply_i(momentum);
+      mu32.add_i(t_reduced32, 1 - momentum);
+
+      deviation32.pow(2.0f, t_full32);
+      t_full32.average(axes_to_reduce, cvar32);
+
+      var32.multiply_i(momentum);
+      var32.add_i(cvar32, 1 - momentum);
+
+      cvar32.add_i(epsilon);
+      cvar32.pow(-0.5f, invstd32);
+    } else {
+      input_32.subtract(mu32, deviation32);
+      /** @todo do below 2 lines only for first iteration */
+      var32.add(epsilon, invstd32);
+      invstd32.pow_i(-0.5f);
+    }
-  if (training) {
-    input_.average(axes_to_reduce, t_reduced);
-    input_.subtract(t_reduced, deviation);
-
-    mu.multiply_i(momentum);
-    mu.add_i(t_reduced, 1 - momentum);
-
-    deviation.pow(2.0f, t_full);
-    t_full.average(axes_to_reduce, cvar);
-
-    var.multiply_i(momentum);
-    var.add_i(cvar, 1 - momentum);
-
-    cvar.add_i(epsilon);
-    cvar.pow(-0.5f, invstd);
+    deviation32.multiply(invstd32, hidden_32);
+    hidden_32.multiply_i(gamma32);
+    hidden_32.add_i(beta32);
+
+    mu.copyData(mu32);
+    var.copyData(var32);
+    gamma.copyData(gamma32);
+    beta.copyData(beta32);
+    input_.copyData(input_32);
+    hidden_.copyData(hidden_32);
+    deviation.copyData(deviation32);
+    invstd.copyData(invstd32);
+    t_reduced.copyData(t_reduced32);
+    cvar.copyData(cvar32);
+#else
+    throw std::runtime_error("enable-fp16 is not enabled");
+#endif
   } else {
-    input_.subtract(mu, deviation);
-    /** @todo do below 2 lines only for first iteration */
-    var.add(epsilon, invstd);
-    invstd.pow_i(-0.5f);
-  }
+    if (training) {
+      input_.average(axes_to_reduce, t_reduced);
+      input_.subtract(t_reduced, deviation);
+
+      mu.multiply_i(momentum);
+      mu.add_i(t_reduced, 1 - momentum);
+
+      deviation.pow(2.0f, t_full);
+      t_full.average(axes_to_reduce, cvar);
+
+      var.multiply_i(momentum);
+      var.add_i(cvar, 1 - momentum);
+
+      cvar.add_i(epsilon);
+      cvar.pow(-0.5f, invstd);
+    } else {
+      input_.subtract(mu, deviation);
+      /** @todo do below 2 lines only for first iteration */
+      var.add(epsilon, invstd);
+      invstd.pow_i(-0.5f);
+    }
 
-  deviation.multiply(invstd, hidden_);
-  hidden_.multiply_i(gamma);
-  hidden_.add_i(beta);
+    deviation.multiply(invstd, hidden_);
+    hidden_.multiply_i(gamma);
+    hidden_.add_i(beta);
+  }
 }
 
 void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
@@ -213,42 +306,151 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
   Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
   Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
 
+  if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    TensorDim gamma_dim = gamma.getDim();
+    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor gamma32(gamma_dim, true);
+    gamma32.copyData(gamma);
+
+    TensorDim deriv_dim = deriv.getDim();
+    deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor deriv32(deriv_dim, true);
+    deriv32.copyData(deriv);
+
+    TensorDim dx_dim = dx.getDim();
+    dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor dx32(dx_dim, true);
+    dx32.copyData(dx);
+
+    TensorDim deviation_dim = deviation.getDim();
+    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor deviation32(deviation_dim, true);
+    deviation32.copyData(deviation);
+
+    TensorDim invstd_dim = invstd.getDim();
+    invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor invstd32(invstd_dim, true);
+    invstd32.copyData(invstd);
+
+    TensorDim cvar_dim = cvar.getDim();
+    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor cvar32(cvar_dim, true);
+    cvar32.copyData(cvar);
+
+    TensorDim t_reduced_dim = t_reduced.getDim();
+    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor t_reduced32(t_reduced_dim, true);
+    t_reduced32.copyData(t_reduced);
+
+    TensorDim t_full_dim = t_full.getDim();
+    t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor t_full32(t_full_dim, true);
+    t_full32.copyData(t_full);
+
+    deviation32.multiply(deriv32, t_full32);
+    t_full32.average(axes_to_reduce, t_reduced32);
+    t_reduced32.divide_i(cvar32);
+    deviation32.multiply_i(t_reduced32);
+
+    if (context.getTrainable()) {
+      /**
+       * This calculates the dgamma tensor.
+       */
+      Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
+      TensorDim dgamma_dim = dgamma.getDim();
+      dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+      Tensor dgamma32(dgamma_dim, true);
+      dgamma32.copyData(dgamma);
+
+      t_full32.multiply_i(invstd32);
+      t_full32.sum(axes_to_reduce, dgamma32);
+      dgamma.copyData(dgamma32);
+
+      /**
+       * This implementation depends on dbeta being pre-calculated.
+       */
+      Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
+      TensorDim dbeta_dim = dbeta.getDim();
+      dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+      Tensor dbeta32(dbeta_dim, true);
+      dbeta32.copyData(dbeta);
+      dbeta32.divide(divider, t_reduced32);
+    } else {
+      deriv32.average(axes_to_reduce, t_reduced32);
+    }
-  deviation.multiply(deriv, t_full);
-  t_full.average(axes_to_reduce, t_reduced);
-  t_reduced.divide_i(cvar);
-  deviation.multiply_i(t_reduced);
-
-  if (context.getTrainable()) {
-    /**
-     * This calculates dgamma tensor.
-     */
-    Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
-    t_full.multiply_i(invstd);
-    t_full.sum(axes_to_reduce, dgamma);
-
-    /**
-     * This implementation depends on the pre-calculated dbeta calculated.
-     */
-    Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
-    dbeta.divide(divider, t_reduced);
+    deriv32.subtract(t_reduced32, dx32);
+    dx32.subtract_i(deviation32);
+
+    invstd32.multiply_i(gamma32);
+    dx32.multiply_i(invstd32);
+
+    gamma.copyData(gamma32);
+    dx.copyData(dx32);
+    deviation.copyData(deviation32);
+    invstd.copyData(invstd32);
+    cvar.copyData(cvar32);
+    t_reduced.copyData(t_reduced32);
+    t_full.copyData(t_full32);
+#else
+    throw std::runtime_error("enable-fp16 is not enabled");
+#endif
   } else {
-    deriv.average(axes_to_reduce, t_reduced);
-  }
+    deviation.multiply(deriv, t_full);
+    t_full.average(axes_to_reduce, t_reduced);
+    t_reduced.divide_i(cvar);
+    deviation.multiply_i(t_reduced);
+
+    if (context.getTrainable()) {
+      /**
+       * This calculates the dgamma tensor.
+       */
+      Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
+      t_full.multiply_i(invstd);
+      t_full.sum(axes_to_reduce, dgamma);
+
+      /**
+       * This implementation depends on dbeta being pre-calculated.
+       */
+      Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
+      dbeta.divide(divider, t_reduced);
+    } else {
+      deriv.average(axes_to_reduce, t_reduced);
+    }
 
-  deriv.subtract(t_reduced, dx);
-  dx.subtract_i(deviation);
+    deriv.subtract(t_reduced, dx);
+    dx.subtract_i(deviation);
 
-  invstd.multiply_i(gamma);
-  dx.multiply_i(invstd);
+    invstd.multiply_i(gamma);
+    dx.multiply_i(invstd);
+  }
 }
 
 void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
   /** dgamma is calculated in calcDerivative. dbeta is calculated here */
   Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
   const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-
-  deriv.sum(axes_to_reduce, dbeta);
+  if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    TensorDim dbeta_dim = dbeta.getDim();
+    dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor dbeta32(dbeta_dim, true);
+    dbeta32.copyData(dbeta);
+
+    TensorDim deriv_dim = deriv.getDim();
+    deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor deriv32(deriv_dim, true);
+    deriv32.copyData(deriv);
+
+    deriv32.sum(axes_to_reduce, dbeta32);
+    dbeta.copyData(dbeta32);
+#else
+    throw std::runtime_error("enable-fp16 is not enabled");
+#endif
+  } else {
+    deriv.sum(axes_to_reduce, dbeta);
+  }
 }
 
 void BatchNormalizationLayer::exportTo(
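---

Notes (illustrative, not part of the patch):

1. Why fp32 for the statistics: the mean/variance reductions in this layer
   sum over all non-channel elements (easily tens of thousands of values),
   and an fp16 accumulator stops growing once its ulp exceeds the addend.
   A minimal standalone demo, assuming a compiler with the _Float16
   extension (recent gcc/clang):

   // fp16_accum_demo.cpp -- illustration only; not nntrainer code.
   #include <cstdio>

   int main() {
     _Float16 sum16 = 0;
     float sum32 = 0.0f;
     // True sum is 10000 * 0.1 = 1000. In fp16 the addend rounds to
     // ~0.0999756, and once the running sum reaches 256 the fp16 ulp is
     // 0.25, so every further addition rounds away and the sum stalls.
     for (int i = 0; i < 10000; ++i) {
       sum16 = sum16 + (_Float16)0.1f;
       sum32 = sum32 + 0.1f;
     }
     printf("fp16 sum: %f\n", (float)sum16); // ~256, far from 1000
     printf("fp32 sum: %f\n", sum32);        // ~1000
     return 0;
   }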
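2. On the "cleaner code" follow-up mentioned in the commit message: every
   block above repeats the same promote/compute/copy-back dance. Under the
   current Tensor API (only getDim/setDataType/copyData, all already used in
   this patch) it could be factored into a small helper; the name toFp32 is
   hypothetical, a sketch rather than an existing nntrainer function:

   // Hypothetical helper: return an fp32 working copy of an fp16 tensor.
   static Tensor toFp32(const Tensor &src) {
     TensorDim dim = src.getDim();
     dim.setDataType(ml::train::TensorDim::DataType::FP32);
     Tensor dst(dim, true); // allocate an fp32 tensor of the same shape
     dst.copyData(src);     // copyData casts fp16 -> fp32 element-wise
     return dst;
   }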
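3. With such a helper, the calcGradient() path above, for example, would
   shrink to something like this (again a sketch under the same assumptions,
   not part of the patch):

   if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
   #ifdef ENABLE_FP16
     Tensor dbeta32 = toFp32(dbeta);
     Tensor deriv32 = toFp32(deriv);

     deriv32.sum(axes_to_reduce, dbeta32); // reduce in fp32
     dbeta.copyData(dbeta32);              // cast the result back to fp16
   #else
     throw std::runtime_error("enable-fp16 is not enabled");
   #endif
   } else {
     deriv.sum(axes_to_reduce, dbeta);
   }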