Mixed precision for bn layer
- According to recent papers on mixed-precision training, the statistics in the BN layer should be computed with FP32 values
- input (FP16) <-> BN layer (FP16 weights, but compute in FP32 and re-cast to FP16) <-> output (FP16); see the sketch below
- Although this requires a bulky code block for now, I believe we can revisit it for cleaner code once TensorV2 becomes official
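For illustration, here is a minimal sketch of the pattern in plain C++ (not nntrainer code; it assumes a compiler that provides `_Float16`, such as recent GCC/Clang, and all names are illustrative): the statistics are accumulated in FP32 and the normalized result is re-cast to FP16.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Batch-norm forward over a single channel: FP16 storage, FP32 compute.
void bn_forward_fp16(const std::vector<_Float16> &x, std::vector<_Float16> &y,
                     float gamma, float beta, float epsilon = 1e-5f) {
  const float n = static_cast<float>(x.size());
  float mean = 0.0f, var = 0.0f; // FP32 accumulators
  for (_Float16 v : x)
    mean += static_cast<float>(v); // widen each element before summing
  mean /= n;
  for (_Float16 v : x) {
    const float d = static_cast<float>(v) - mean;
    var += d * d;
  }
  var /= n;
  const float invstd = 1.0f / std::sqrt(var + epsilon);
  y.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float norm = (static_cast<float>(x[i]) - mean) * invstd;
    y[i] = static_cast<_Float16>(gamma * norm + beta); // re-cast to FP16
  }
}

int main() {
  std::vector<_Float16> x = {static_cast<_Float16>(1.0f),
                             static_cast<_Float16>(2.0f),
                             static_cast<_Float16>(3.0f)};
  std::vector<_Float16> y;
  bn_forward_fp16(x, y, /*gamma=*/1.0f, /*beta=*/0.0f);
  for (_Float16 v : y)
    std::printf("%f\n", static_cast<float>(v));
}
```

Accumulating in FP32 matters because FP16 has only a ~10-bit mantissa and overflows past 65504, so long reductions lose precision or saturate when summed natively.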

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
skykongkong8 authored and jihochu committed Feb 21, 2024
1 parent 68dd18b commit 24cad9d
300 changes: 251 additions & 49 deletions nntrainer/layers/bn_layer.cpp
@@ -174,32 +174,125 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
/** use hidden_ as temporary tensor before setting the result in hidden */
Tensor t_full = hidden_;
Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
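/** Make temporary FP32 copies of the FP16 tensors so that the statistics
 *  below are accumulated in full precision. */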
TensorDim mu_dim = mu.getDim();
mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor mu32(mu_dim, true);
mu32.copyData(mu);

TensorDim var_dim = var.getDim();
var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor var32(var_dim, true);
var32.copyData(var);

TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim beta_dim = beta.getDim();
beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor beta32(beta_dim, true);
beta32.copyData(beta);

TensorDim input_dim = input_.getDim();
input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor input_32(input_dim, true);
input_32.copyData(input_);

TensorDim hidden_dim = hidden_.getDim();
hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor hidden_32(hidden_dim, true);
hidden_32.copyData(hidden_);
Tensor t_full32 = hidden_32;

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim dim_invstd = invstd.getDim();
dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(dim_invstd, true);
invstd32.copyData(invstd);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

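/** Training: compute batch statistics in FP32 and update the running
 *  mean / variance; inference: normalize with the running statistics. */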
if (training) {
input_32.average(axes_to_reduce, t_reduced32);
input_32.subtract(t_reduced32, deviation32);

mu32.multiply_i(momentum);
mu32.add_i(t_reduced32, 1 - momentum);

deviation32.pow(2.0f, t_full32);
t_full32.average(axes_to_reduce, cvar32);

var32.multiply_i(momentum);
var32.add_i(cvar32, 1 - momentum);

cvar32.add_i(epsilon);
cvar32.pow(-0.5f, invstd32);
} else {
input_32.subtract(mu32, deviation32);
/** @todo do below 2 lines only for first iteration */
var32.add(epsilon, invstd32);
invstd32.pow_i(-0.5f);
}

deviation32.multiply(invstd32, hidden_32);
hidden_32.multiply_i(gamma32);
hidden_32.add_i(beta32);

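/** Re-cast the FP32 results back into the layer's FP16 tensors. */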
mu.copyData(mu32);
var.copyData(var32);
gamma.copyData(gamma32);
beta.copyData(beta32);
input_.copyData(input_32);
hidden_.copyData(hidden_32);
deviation.copyData(deviation32);
invstd.copyData(invstd32);
t_reduced.copyData(t_reduced32);
cvar.copyData(cvar32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
} else {
if (training) {
input_.average(axes_to_reduce, t_reduced);
input_.subtract(t_reduced, deviation);

mu.multiply_i(momentum);
mu.add_i(t_reduced, 1 - momentum);

deviation.pow(2.0f, t_full);
t_full.average(axes_to_reduce, cvar);

var.multiply_i(momentum);
var.add_i(cvar, 1 - momentum);

cvar.add_i(epsilon);
cvar.pow(-0.5f, invstd);
} else {
input_.subtract(mu, deviation);
/** @todo do below 2 lines only for first iteration */
var.add(epsilon, invstd);
invstd.pow_i(-0.5f);
}

deviation.multiply(invstd, hidden_);
hidden_.multiply_i(gamma);
hidden_.add_i(beta);
}
}

@@ -213,42 +306,151 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
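/** As in forwarding, make temporary FP32 copies so that the gradient
 *  computation below runs in full precision. */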
TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

TensorDim dx_dim = dx.getDim();
dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dx32(dx_dim, true);
dx32.copyData(dx);

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim invstd_dim = invstd.getDim();
invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(invstd_dim, true);
invstd32.copyData(invstd);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim t_full_dim = t_full.getDim();
t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_full32(t_full_dim, true);
t_full32.copyData(t_full);

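/** Batch-normalization backward pass, computed on the FP32 copies. */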
deviation32.multiply(deriv32, t_full32);
t_full32.average(axes_to_reduce, t_reduced32);
t_reduced32.divide_i(cvar32);
deviation32.multiply_i(t_reduced32);

if (context.getTrainable()) {
/**
 * This calculates the dgamma tensor.
 */
Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
TensorDim dgamma_dim = dgamma.getDim();
dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dgamma32(dgamma_dim, true);
dgamma32.copyData(dgamma);

t_full32.multiply_i(invstd32);
t_full32.sum(axes_to_reduce, dgamma32);
dgamma.copyData(dgamma32);

/**
 * This implementation depends on dbeta being pre-calculated in calcGradient.
 */
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);
dbeta32.divide(divider, t_reduced32);
} else {
deriv32.average(axes_to_reduce, t_reduced32);
}

deriv32.subtract(t_reduced32, dx32);
dx32.subtract_i(deviation32);

invstd32.multiply_i(gamma32);
dx32.multiply_i(invstd32);

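/** Re-cast the FP32 gradients back into the FP16 tensors. */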
gamma.copyData(gamma32);
dx.copyData(dx32);
deviation.copyData(deviation32);
invstd.copyData(invstd32);
cvar.copyData(cvar32);
t_reduced.copyData(t_reduced32);
t_full.copyData(t_full32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
} else {
deviation.multiply(deriv, t_full);
t_full.average(axes_to_reduce, t_reduced);
t_reduced.divide_i(cvar);
deviation.multiply_i(t_reduced);

if (context.getTrainable()) {
/**
 * This calculates the dgamma tensor.
 */
Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
t_full.multiply_i(invstd);
t_full.sum(axes_to_reduce, dgamma);

/**
 * This implementation depends on dbeta being pre-calculated in calcGradient.
 */
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
dbeta.divide(divider, t_reduced);
} else {
deriv.average(axes_to_reduce, t_reduced);
}

deriv.subtract(t_reduced, dx);
dx.subtract_i(deviation);

invstd.multiply_i(gamma);
dx.multiply_i(invstd);
}
}

void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
/** dgamma is calculated in calcDerivative. dbeta is calculated here */
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);

if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
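/** Accumulate dbeta in FP32, then re-cast it back to FP16. */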
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

deriv32.sum(axes_to_reduce, dbeta32);
dbeta.copyData(dbeta32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
} else {
deriv.sum(axes_to_reduce, dbeta);
}
}

