[ Trivial ] Shorter code for Half-precision BN layer
- For the mixed-precision training computation of the BN layer, we have been declaring temporary single-precision Tensors, computing with Tensor ops, and saving the result back into the original half-precision Tensor.
- To remove the redundant code, declare a temporary helper function and reuse it (see the before/after sketch below).
- We should revisit this for even clearer code once the TensorV2 refactoring is finished.
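
For context, this is the per-Tensor pattern being factored out and the one-line call that replaces it, both taken from the diff below:

```cpp
// Before: repeated once for every Tensor touched by the FP16 path
TensorDim mu_dim = mu.getDim();
mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor mu32(mu_dim, true); // allocate a temporary single-precision Tensor
mu32.copyData(mu);         // copy (and convert) the half-precision data

// After: one call to the new helper in tensor.h
Tensor mu32 = mu.getSingleTensor();
```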

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
skykongkong8 committed Feb 26, 2024
1 parent da40de1 commit 534d25c
Showing 2 changed files with 36 additions and 99 deletions.
121 changes: 22 additions & 99 deletions nntrainer/layers/bn_layer.cpp
@@ -176,56 +176,17 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim mu_dim = mu.getDim();
mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor mu32(mu_dim, true);
mu32.copyData(mu);

TensorDim var_dim = var.getDim();
var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor var32(var_dim, true);
var32.copyData(var);

TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim beta_dim = beta.getDim();
beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor beta32(beta_dim, true);
beta32.copyData(beta);

TensorDim input_dim = input_.getDim();
input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor input_32(input_dim, true);
input_32.copyData(input_);

TensorDim hidden_dim = hidden_.getDim();
hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor hidden_32(hidden_dim, true);
hidden_32.copyData(hidden_);
Tensor t_full32 = hidden_32;

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim dim_invstd = invstd.getDim();
dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(dim_invstd, true);
invstd32.copyData(invstd);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);
Tensor mu32 = mu.getSingleTensor();
Tensor var32 = var.getSingleTensor();
Tensor gamma32 = gamma.getSingleTensor();
Tensor beta32 = beta.getSingleTensor();
Tensor input_32 = input_.getSingleTensor();
Tensor hidden_32 = hidden_.getSingleTensor();
Tensor t_full32 = hidden_32;
Tensor deviation32 = deviation.getSingleTensor();
Tensor invstd32 = invstd.getSingleTensor();
Tensor t_reduced32 = t_reduced.getSingleTensor();
Tensor cvar32 = cvar.getSingleTensor();

if (training) {
input_32.average(axes_to_reduce, t_reduced32);
@@ -308,45 +269,14 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

TensorDim dx_dim = dx.getDim();
dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dx32(dx_dim, true);
dx32.copyData(dx);

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim invstd_dim = invstd.getDim();
invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(invstd_dim, true);
invstd32.copyData(invstd);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim t_full_dim = t_full.getDim();
t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_full32(t_full_dim, true);
t_full32.copyData(t_full);
Tensor gamma32 = gamma.getSingleTensor();
Tensor deriv32 = deriv.getSingleTensor();
Tensor dx32 = dx.getSingleTensor();
Tensor deviation32 = deviation.getSingleTensor();
Tensor invstd32 = invstd.getSingleTensor();
Tensor cvar32 = cvar.getSingleTensor();
Tensor t_reduced32 = t_reduced.getSingleTensor();
Tensor t_full32 = t_full.getSingleTensor();

deviation32.multiply(deriv32, t_full32);
t_full32.average(axes_to_reduce, t_reduced32);
@@ -357,12 +287,8 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
/**
* This calculates dgamma tensor.
*/
Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
TensorDim dgamma_dim = dgamma.getDim();
dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dgamma32(dgamma_dim, true);
dgamma32.copyData(dgamma);

Tensor dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
Tensor dgamma32 = dgamma.getSingleTensor();
t_full32.multiply_i(invstd32);
t_full32.sum(axes_to_reduce, dgamma32);
dgamma.copyData(dgamma32);
@@ -371,10 +297,7 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
* This implementation depends on the pre-calculated dbeta calculated.
*/
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);
Tensor dbeta32 = dbeta.getSingleTensor();
dbeta32.divide(divider, t_reduced32);
} else {
deriv32.average(axes_to_reduce, t_reduced32);
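Every FP16 branch in bn_layer.cpp now follows the same three-step round trip. A condensed sketch, reusing the dgamma lines from the hunk above:

```cpp
Tensor dgamma32 = dgamma.getSingleTensor(); // 1. FP16 -> temporary FP32 copy
t_full32.multiply_i(invstd32);              // 2. compute with FP32 Tensor ops
t_full32.sum(axes_to_reduce, dgamma32);
dgamma.copyData(dgamma32);                  // 3. store the FP32 result back into the FP16 gradient
```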
14 changes: 14 additions & 0 deletions nntrainer/tensor/tensor.h
@@ -2014,6 +2014,20 @@ class Tensor {

scale_factors_fp16 = scales;
}

/**
* @brief Get a single-precision (FP32) copy of this Tensor
*
* @return Tensor an FP32 Tensor with the same dimensions, holding a copy of
* this Tensor's data
*/
Tensor getSingleTensor() const {
TensorDim output_dim = getDim();
output_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor output(output_dim, true);
output.copyData(*this);
return output;
}
#endif

/**
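A minimal usage sketch for the new helper; the caller code here is hypothetical and not part of the commit, and it assumes `dim` is a TensorDim whose data type is FP16:

```cpp
Tensor half_t(dim, true);                 // half-precision Tensor
Tensor fp32_t = half_t.getSingleTensor(); // same shape, FP32 copy of the data
// ... compute with full-precision Tensor ops ...
half_t.copyData(fp32_t);                  // write the FP32 result back as FP16
```

Note that each call allocates a fresh FP32 buffer and copies the data, which is part of why the commit message flags this code for another pass once the TensorV2 refactoring is finished.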
