From 534d25c774dfda8f6a9ced6ac9bfc148f9c3aaa1 Mon Sep 17 00:00:00 2001
From: skykongkong8 <ss.kong@samsung.com>
Date: Mon, 26 Feb 2024 14:14:51 +0900
Subject: [PATCH] [ Trivial ] Shorter code for Half-precision BN layer

- For the mixed-precision training computation of the BN layer, we have been declaring temporary single-precision Tensors, computing with Tensor ops, and saving the result back into the original half-precision Tensors.
- To remove the redundant code, declare a helper function and reuse it, as sketched below.
- We should revisit this for even cleaner code when the TensorV2 refactoring is finished.
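With the helper, each multi-line conversion block collapses into a single call. A minimal sketch of the resulting FP32 round trip, following the dgamma path in bn_layer.cpp:

```cpp
// FP16 gradient -> temporary FP32 copy (allocates a new buffer and cast-copies)
Tensor dgamma32 = dgamma.getSingleTensor();
// run the math with single-precision Tensor ops
t_full32.sum(axes_to_reduce, dgamma32);
// store the single-precision result back into the half-precision Tensor
dgamma.copyData(dgamma32);
```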

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
---
 nntrainer/layers/bn_layer.cpp | 121 +++++++---------------------------
 nntrainer/tensor/tensor.h     |  14 ++++
 2 files changed, 36 insertions(+), 99 deletions(-)

diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index e3c179d1f0..2c58b91eda 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -176,56 +176,17 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
   Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
   if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    TensorDim mu_dim = mu.getDim();
-    mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor mu32(mu_dim, true);
-    mu32.copyData(mu);
-
-    TensorDim var_dim = var.getDim();
-    var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor var32(var_dim, true);
-    var32.copyData(var);
-
-    TensorDim gamma_dim = gamma.getDim();
-    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor gamma32(gamma_dim, true);
-    gamma32.copyData(gamma);
-
-    TensorDim beta_dim = beta.getDim();
-    beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor beta32(beta_dim, true);
-    beta32.copyData(beta);
-
-    TensorDim input_dim = input_.getDim();
-    input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor input_32(input_dim, true);
-    input_32.copyData(input_);
-
-    TensorDim hidden_dim = hidden_.getDim();
-    hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor hidden_32(hidden_dim, true);
-    hidden_32.copyData(hidden_);
-    Tensor t_full32 = hidden_32;
-
-    TensorDim deviation_dim = deviation.getDim();
-    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deviation32(deviation_dim, true);
-    deviation32.copyData(deviation);
-
-    TensorDim dim_invstd = invstd.getDim();
-    dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor invstd32(dim_invstd, true);
-    invstd32.copyData(invstd);
-
-    TensorDim t_reduced_dim = t_reduced.getDim();
-    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_reduced32(t_reduced_dim, true);
-    t_reduced32.copyData(t_reduced);
-
-    TensorDim cvar_dim = cvar.getDim();
-    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor cvar32(cvar_dim, true);
-    cvar32.copyData(cvar);
+    Tensor mu32 = mu.getSingleTensor();
+    Tensor var32 = var.getSingleTensor();
+    Tensor gamma32 = gamma.getSingleTensor();
+    Tensor beta32 = beta.getSingleTensor();
+    Tensor input_32 = input_.getSingleTensor();
+    Tensor hidden_32 = hidden_.getSingleTensor();
+    Tensor t_full32 = hidden_32;
+    Tensor deviation32 = deviation.getSingleTensor();
+    Tensor invstd32 = invstd.getSingleTensor();
+    Tensor t_reduced32 = t_reduced.getSingleTensor();
+    Tensor cvar32 = cvar.getSingleTensor();
 
     if (training) {
       input_32.average(axes_to_reduce, t_reduced32);
@@ -308,45 +269,14 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
   Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
   if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    TensorDim gamma_dim = gamma.getDim();
-    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor gamma32(gamma_dim, true);
-    gamma32.copyData(gamma);
-
-    TensorDim deriv_dim = deriv.getDim();
-    deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deriv32(deriv_dim, true);
-    deriv32.copyData(deriv);
-
-    TensorDim dx_dim = dx.getDim();
-    dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor dx32(dx_dim, true);
-    dx32.copyData(dx);
-
-    TensorDim deviation_dim = deviation.getDim();
-    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deviation32(deviation_dim, true);
-    deviation32.copyData(deviation);
-
-    TensorDim invstd_dim = invstd.getDim();
-    invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor invstd32(invstd_dim, true);
-    invstd32.copyData(invstd);
-
-    TensorDim cvar_dim = cvar.getDim();
-    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor cvar32(cvar_dim, true);
-    cvar32.copyData(cvar);
-
-    TensorDim t_reduced_dim = t_reduced.getDim();
-    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_reduced32(t_reduced_dim, true);
-    t_reduced32.copyData(t_reduced);
-
-    TensorDim t_full_dim = t_full.getDim();
-    t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_full32(t_full_dim, true);
-    t_full32.copyData(t_full);
+    Tensor gamma32 = gamma.getSingleTensor();
+    Tensor deriv32 = deriv.getSingleTensor();
+    Tensor dx32 = dx.getSingleTensor();
+    Tensor deviation32 = deviation.getSingleTensor();
+    Tensor invstd32 = invstd.getSingleTensor();
+    Tensor cvar32 = cvar.getSingleTensor();
+    Tensor t_reduced32 = t_reduced.getSingleTensor();
+    Tensor t_full32 = t_full.getSingleTensor();
 
     deviation32.multiply(deriv32, t_full32);
     t_full32.average(axes_to_reduce, t_reduced32);
@@ -357,12 +287,8 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
       /**
        * This calculates dgamma tensor.
        */
-      Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
-      TensorDim dgamma_dim = dgamma.getDim();
-      dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-      Tensor dgamma32(dgamma_dim, true);
-      dgamma32.copyData(dgamma);
-
+      Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
+      Tensor dgamma32 = dgamma.getSingleTensor();
       t_full32.multiply_i(invstd32);
       t_full32.sum(axes_to_reduce, dgamma32);
       dgamma.copyData(dgamma32);
@@ -371,10 +297,7 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
        * This implementation depends on the pre-calculated dbeta calculated.
        */
       Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
-      TensorDim dbeta_dim = dbeta.getDim();
-      dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-      Tensor dbeta32(dbeta_dim, true);
-      dbeta32.copyData(dbeta);
+      Tensor dbeta32 = dbeta.getSingleTensor();
       dbeta32.divide(divider, t_reduced32);
     } else {
       deriv32.average(axes_to_reduce, t_reduced32);
diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h
index 211334da40..a8c5b1cdb6 100644
--- a/nntrainer/tensor/tensor.h
+++ b/nntrainer/tensor/tensor.h
@@ -2014,6 +2014,20 @@ class Tensor {
 
     scale_factors_fp16 = scales;
   }
+
+  /**
+   * @brief Get a single-precision (FP32) copy of this Tensor
+   *
+   * @note Allocates a new FP32 Tensor and cast-copies the data
+   * @return Tensor FP32 copy of this Tensor
+   */
+  Tensor getSingleTensor() const {
+    TensorDim output_dim = getDim();
+    output_dim.setDataType(ml::train::TensorDim::DataType::FP32);
+    Tensor output(output_dim, true);
+    output.copyData(*this);
+    return output;
+  }
 #endif
 
   /**