From 534d25c774dfda8f6a9ced6ac9bfc148f9c3aaa1 Mon Sep 17 00:00:00 2001
From: skykongkong8 <ss.kong@samsung.com>
Date: Mon, 26 Feb 2024 14:14:51 +0900
Subject: [PATCH] [ Trivial ] Shorter code for half-precision BN layer

- For mixed-precision training in the BN layer, we have been declaring
  temporary single-precision Tensors, computing with Tensor ops, and
  saving the result back into the original half-precision Tensors.
- To remove the redundant code, declare a helper function and reuse it.
- We should revisit this for even clearer code once the TensorV2
  refactoring is finished.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
---
 nntrainer/layers/bn_layer.cpp | 121 +++++++---------------------
 nntrainer/tensor/tensor.h     |  14 ++++
 2 files changed, 36 insertions(+), 99 deletions(-)

diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index e3c179d1f0..2c58b91eda 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -176,56 +176,17 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
   Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
   if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    TensorDim mu_dim = mu.getDim();
-    mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor mu32(mu_dim, true);
-    mu32.copyData(mu);
-
-    TensorDim var_dim = var.getDim();
-    var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor var32(var_dim, true);
-    var32.copyData(var);
-
-    TensorDim gamma_dim = gamma.getDim();
-    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor gamma32(gamma_dim, true);
-    gamma32.copyData(gamma);
-
-    TensorDim beta_dim = beta.getDim();
-    beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor beta32(beta_dim, true);
-    beta32.copyData(beta);
-
-    TensorDim input_dim = input_.getDim();
-    input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor input_32(input_dim, true);
-    input_32.copyData(input_);
-
-    TensorDim hidden_dim = hidden_.getDim();
-    hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor hidden_32(hidden_dim, true);
-    hidden_32.copyData(hidden_);
-    Tensor t_full32 = hidden_32;
-
-    TensorDim deviation_dim = deviation.getDim();
-    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deviation32(deviation_dim, true);
-    deviation32.copyData(deviation);
-
-    TensorDim dim_invstd = invstd.getDim();
-    dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor invstd32(dim_invstd, true);
-    invstd32.copyData(invstd);
-
-    TensorDim t_reduced_dim = t_reduced.getDim();
-    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_reduced32(t_reduced_dim, true);
-    t_reduced32.copyData(t_reduced);
-
-    TensorDim cvar_dim = cvar.getDim();
-    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor cvar32(cvar_dim, true);
-    cvar32.copyData(cvar);
+    Tensor mu32 = mu.getSingleTensor();
+    Tensor var32 = var.getSingleTensor();
+    Tensor gamma32 = gamma.getSingleTensor();
+    Tensor beta32 = beta.getSingleTensor();
+    Tensor input_32 = input_.getSingleTensor();
+    Tensor hidden_32 = hidden_.getSingleTensor();
+    Tensor t_full32 = hidden_32;
+    Tensor deviation32 = deviation.getSingleTensor();
+    Tensor invstd32 = invstd.getSingleTensor();
+    Tensor t_reduced32 = t_reduced.getSingleTensor();
+    Tensor cvar32 = cvar.getSingleTensor();
 
     if (training) {
       input_32.average(axes_to_reduce, t_reduced32);
@@ -308,45 +269,14 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
   Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
   if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    TensorDim gamma_dim = gamma.getDim();
-    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor gamma32(gamma_dim, true);
-    gamma32.copyData(gamma);
-
-    TensorDim deriv_dim = deriv.getDim();
-    deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deriv32(deriv_dim, true);
-    deriv32.copyData(deriv);
-
-    TensorDim dx_dim = dx.getDim();
-    dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor dx32(dx_dim, true);
-    dx32.copyData(dx);
-
-    TensorDim deviation_dim = deviation.getDim();
-    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deviation32(deviation_dim, true);
-    deviation32.copyData(deviation);
-
-    TensorDim invstd_dim = invstd.getDim();
-    invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor invstd32(invstd_dim, true);
-    invstd32.copyData(invstd);
-
-    TensorDim cvar_dim = cvar.getDim();
-    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor cvar32(cvar_dim, true);
-    cvar32.copyData(cvar);
-
-    TensorDim t_reduced_dim = t_reduced.getDim();
-    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_reduced32(t_reduced_dim, true);
-    t_reduced32.copyData(t_reduced);
-
-    TensorDim t_full_dim = t_full.getDim();
-    t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_full32(t_full_dim, true);
-    t_full32.copyData(t_full);
+    Tensor gamma32 = gamma.getSingleTensor();
+    Tensor deriv32 = deriv.getSingleTensor();
+    Tensor dx32 = dx.getSingleTensor();
+    Tensor deviation32 = deviation.getSingleTensor();
+    Tensor invstd32 = invstd.getSingleTensor();
+    Tensor cvar32 = cvar.getSingleTensor();
+    Tensor t_reduced32 = t_reduced.getSingleTensor();
+    Tensor t_full32 = t_full.getSingleTensor();
 
     deviation32.multiply(deriv32, t_full32);
     t_full32.average(axes_to_reduce, t_reduced32);
@@ -357,12 +287,8 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
     /**
      * This calculates dgamma tensor.
      */
-    Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
-    TensorDim dgamma_dim = dgamma.getDim();
-    dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor dgamma32(dgamma_dim, true);
-    dgamma32.copyData(dgamma);
-
+    Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
+    Tensor dgamma32 = dgamma.getSingleTensor();
     t_full32.multiply_i(invstd32);
     t_full32.sum(axes_to_reduce, dgamma32);
     dgamma.copyData(dgamma32);
@@ -371,10 +297,7 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
      * This implementation depends on the pre-calculated dbeta calculated.
      */
     Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
-    TensorDim dbeta_dim = dbeta.getDim();
-    dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor dbeta32(dbeta_dim, true);
-    dbeta32.copyData(dbeta);
+    Tensor dbeta32 = dbeta.getSingleTensor();
     dbeta32.divide(divider, t_reduced32);
   } else {
     deriv32.average(axes_to_reduce, t_reduced32);
diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h
index 211334da40..a8c5b1cdb6 100644
--- a/nntrainer/tensor/tensor.h
+++ b/nntrainer/tensor/tensor.h
@@ -2014,6 +2014,20 @@ class Tensor {
     scale_factors_fp16 = scales;
   }
 
+
+  /**
+   * @brief Get an FP32 (single-precision) copy of this Tensor
+   *
+   * @note allocates a new FP32 Tensor and copies this Tensor's data into it
+   * @return Tensor FP32 copy of this Tensor
+   */
+  Tensor getSingleTensor() const {
+    TensorDim output_dim = getDim();
+    output_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor output(output_dim, true);
+    output.copyData(*this);
+    return output;
+  }
 #endif
 
   /**
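
Note for reviewers: the pattern this patch factors out is "widen to FP32, compute, copy back". The sketch below illustrates that flow with a hypothetical MiniTensor stand-in; it is not nntrainer code, toFloat32() merely plays the role of the new Tensor::getSingleTensor(), and half-precision storage is modeled with plain float so the example compiles anywhere.

// build: g++ -std=c++17 widen_pattern.cpp
#include <cstddef>
#include <iostream>
#include <vector>

enum class DType { FP16, FP32 };

// Hypothetical stand-in for nntrainer's Tensor.
struct MiniTensor {
  DType dtype;
  std::vector<float> data; // real FP16 storage elided for portability

  // Plays the role of the new Tensor::getSingleTensor(): same data, FP32 type.
  MiniTensor toFloat32() const { return {DType::FP32, data}; }

  // Plays the role of Tensor::copyData(): overwrite values from another tensor.
  void copyData(const MiniTensor &src) { data = src.data; }
};

int main() {
  MiniTensor gamma{DType::FP16, {0.5f, 1.0f, 2.0f}};
  MiniTensor dgamma{DType::FP16, {0.0f, 0.0f, 0.0f}};

  // 1) widen the half-precision tensors to single precision
  MiniTensor gamma32 = gamma.toFloat32();
  MiniTensor dgamma32 = dgamma.toFloat32();

  // 2) run the arithmetic in FP32 (stand-in for the multiply/average/sum ops)
  for (std::size_t i = 0; i < gamma32.data.size(); ++i)
    dgamma32.data[i] = gamma32.data[i] * 2.0f;

  // 3) copy the FP32 result back into the half-precision tensor
  dgamma.copyData(dgamma32);

  for (float v : dgamma.data)
    std::cout << v << ' '; // prints: 1 2 4
  std::cout << '\n';
  return 0;
}

The same three steps appear in forwarding() and calcDerivative() above: each *32 temporary comes from getSingleTensor(), the math runs in single precision, and copyData() writes the result back into the half-precision tensor.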