Modify bn layer for mixed precision
Signed-off-by: Jiho Chu <[email protected]>
jihochu committed Feb 21, 2024
1 parent 24cad9d commit 2d15c67
Showing 1 changed file with 77 additions and 176 deletions.
253 changes: 77 additions & 176 deletions nntrainer/layers/bn_layer.cpp
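In outline, the commit drops the FP16-specific #ifdef ENABLE_FP16 path, which shadowed every tensor with a hand-built FP32 copy (the *32 tensors below), in favor of a single data-type check: whenever the activation data type differs from the parameter data type, the parameters are cloned into the activation type, the batch-norm math runs in that type, and the results are copied back into the original tensors. A minimal, framework-free sketch of that pattern follows; SimpleTensor, DType and the placeholder update are illustrative stand-ins, not nntrainer API.

#include <vector>

// Illustrative stand-ins only; nntrainer's Tensor/TensorDim interfaces differ.
enum class DType { FP16, FP32 };

struct SimpleTensor {
  DType dtype;
  std::vector<float> data;  // values held as float regardless of nominal dtype

  SimpleTensor clone(DType t) const { return SimpleTensor{t, data}; }
  void copyData(const SimpleTensor &src) { data = src.data; }  // keeps own dtype
};

// The pattern used by the new code: if the parameter's data type does not
// match the activation's, compute on a clone in the activation type and copy
// the result back into the original (possibly FP16) parameter.
void run_in_activation_type(SimpleTensor &param, const SimpleTensor &activation) {
  if (param.dtype != activation.dtype) {
    SimpleTensor param_ = param.clone(activation.dtype);
    for (float &v : param_.data)
      v *= 2.0f;               // placeholder for the real batch-norm update
    param.copyData(param_);    // write results back to the mixed-precision weight
  } else {
    for (float &v : param.data)
      v *= 2.0f;               // same-precision path, no extra copies
  }
}

The same check-and-clone shape appears in forwarding(), calcDerivative() and calcGradient() in the diff below, which appears to be what lets the #else / throw branches go away.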
@@ -174,98 +174,53 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
/** use hidden_ as temporary tensor before setting the result in hidden */
Tensor t_full = hidden_;
Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim mu_dim = mu.getDim();
mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor mu32(mu_dim, true);
mu32.copyData(mu);

TensorDim var_dim = var.getDim();
var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor var32(var_dim, true);
var32.copyData(var);

TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim beta_dim = beta.getDim();
beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor beta32(beta_dim, true);
beta32.copyData(beta);

TensorDim input_dim = input_.getDim();
input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor input_32(input_dim, true);
input_32.copyData(input_);

TensorDim hidden_dim = hidden_.getDim();
hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor hidden_32(hidden_dim, true);
hidden_32.copyData(hidden_);
Tensor t_full32 = hidden_32;

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim dim_invstd = invstd.getDim();
dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(dim_invstd, true);
invstd32.copyData(invstd);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

const auto &in_type = input_.getDataType();
if (in_type != mu.getDataType()) {
// Compute in the activation data type
Tensor mu_ = mu.clone(in_type);
Tensor var_ = var.clone(in_type);
Tensor gamma_ = gamma.clone(in_type);
Tensor beta_ = beta.clone(in_type);
Tensor deviation_ = deviation.clone(in_type);
Tensor invstd_ = invstd.clone(in_type);
Tensor t_reduced_ = t_reduced.clone(in_type);
Tensor cvar_ = cvar.clone(in_type);

if (training) {
input_32.average(axes_to_reduce, t_reduced32);
input_32.subtract(t_reduced32, deviation32);
input_.average(axes_to_reduce, t_reduced_);
input_.subtract(t_reduced_, deviation_);

mu32.multiply_i(momentum);
mu32.add_i(t_reduced32, 1 - momentum);
mu_.multiply_i(momentum);
mu_.add_i(t_reduced_, 1 - momentum);

deviation32.pow(2.0f, t_full32);
t_full32.average(axes_to_reduce, cvar32);
deviation_.pow(2.0f, t_full);
t_full.average(axes_to_reduce, cvar_);

var32.multiply_i(momentum);
var32.add_i(cvar32, 1 - momentum);
var_.multiply_i(momentum);
var_.add_i(cvar_, 1 - momentum);

cvar32.add_i(epsilon);
cvar32.pow(-0.5f, invstd32);
cvar_.add_i(epsilon);
cvar_.pow(-0.5f, invstd_);
} else {
input_32.subtract(mu32, deviation32);
input_.subtract(mu_, deviation_);
/** @todo do below 2 lines only for first iteration */
var32.add(epsilon, invstd32);
invstd32.pow_i(-0.5f);
var_.add(epsilon, invstd_);
invstd_.pow_i(-0.5f);
}

deviation32.multiply(invstd32, hidden_32);
hidden_32.multiply_i(gamma32);
hidden_32.add_i(beta32);

mu.copyData(mu32);
var.copyData(var32);
gamma.copyData(gamma32);
beta.copyData(beta32);
input_.copyData(input_32);
hidden_.copyData(hidden_32);
deviation.copyData(deviation32);
invstd.copyData(invstd32);
t_reduced.copyData(t_reduced32);
cvar.copyData(cvar32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
deviation_.multiply(invstd_, hidden_);
hidden_.multiply_i(gamma_);
hidden_.add_i(beta_);

mu.copyData(mu_);
var.copyData(var_);
gamma.copyData(gamma_);
beta.copyData(beta_);
deviation.copyData(deviation_);
invstd.copyData(invstd_);
t_reduced.copyData(t_reduced_);
cvar.copyData(cvar_);
} else {
if (training) {
input_.average(axes_to_reduce, t_reduced);
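For reference, the statistics maintained by the forwarding path above can be written out for a single channel in plain C++. This is a self-contained sketch of the training-mode math only (mean, deviation, variance, running-stat update with momentum, then normalize, scale and shift); it is not nntrainer code and the function name is made up.

#include <cmath>
#include <cstddef>
#include <vector>

// One-channel batch-norm forward in training mode, following the sequence of
// tensor ops above: mean -> deviation -> variance -> running stats -> invstd
// -> normalize/scale/shift. Plain-float illustration only.
void bn_forward_train(const std::vector<float> &x, float gamma, float beta,
                      float momentum, float epsilon, float &running_mu,
                      float &running_var, std::vector<float> &y) {
  const std::size_t n = x.size();

  float mean = 0.0f;                       // t_reduced
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);

  std::vector<float> dev(n);               // deviation = x - mean
  float cvar = 0.0f;                       // cvar = mean(deviation^2)
  for (std::size_t i = 0; i < n; ++i) {
    dev[i] = x[i] - mean;
    cvar += dev[i] * dev[i];
  }
  cvar /= static_cast<float>(n);

  // Exponential moving averages of the batch statistics.
  running_mu = momentum * running_mu + (1.0f - momentum) * mean;
  running_var = momentum * running_var + (1.0f - momentum) * cvar;

  const float invstd = 1.0f / std::sqrt(cvar + epsilon);  // (cvar + eps)^-0.5

  y.resize(n);
  for (std::size_t i = 0; i < n; ++i)
    y[i] = gamma * dev[i] * invstd + beta;                // scale and shift
}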
@@ -306,96 +261,53 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

TensorDim dx_dim = dx.getDim();
dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dx32(dx_dim, true);
dx32.copyData(dx);

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim invstd_dim = invstd.getDim();
invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(invstd_dim, true);
invstd32.copyData(invstd);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim t_full_dim = t_full.getDim();
t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_full32(t_full_dim, true);
t_full32.copyData(t_full);

deviation32.multiply(deriv32, t_full32);
t_full32.average(axes_to_reduce, t_reduced32);
t_reduced32.divide_i(cvar32);
deviation32.multiply_i(t_reduced32);

const auto &deriv_type = deriv.getDataType();

if (deriv_type != gamma.getDataType()) {
Tensor gamma_ = gamma.clone(deriv_type);
Tensor deviation_ = deviation.clone(deriv_type);
Tensor invstd_ = invstd.clone(deriv_type);
Tensor cvar_ = cvar.clone(deriv_type);
Tensor t_reduced_ = t_reduced.clone(deriv_type);
Tensor t_full_ = t_full.clone(deriv_type);

deviation_.multiply(deriv, t_full_);
t_full_.average(axes_to_reduce, t_reduced_);
t_reduced_.divide_i(cvar_);
deviation_.multiply_i(t_reduced_);

if (context.getTrainable()) {
/**
* This calculates the dgamma tensor.
*/
Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
TensorDim dgamma_dim = dgamma.getDim();
dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dgamma32(dgamma_dim, true);
dgamma32.copyData(dgamma);

t_full32.multiply_i(invstd32);
t_full32.sum(axes_to_reduce, dgamma32);
dgamma.copyData(dgamma32);
Tensor dgamma_ = dgamma.clone(deriv_type);
t_full_.multiply_i(invstd_);
t_full_.sum(axes_to_reduce, dgamma_);
dgamma.copyData(dgamma_);

/**
* This implementation depends on dbeta having been pre-calculated.
*/
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);
dbeta32.divide(divider, t_reduced32);
Tensor dbeta_ = dbeta.clone(deriv_type);
dbeta_.divide(divider, t_reduced_);
} else {
deriv32.average(axes_to_reduce, t_reduced32);
deriv.average(axes_to_reduce, t_reduced_);
}

deriv32.subtract(t_reduced32, dx32);
dx32.subtract_i(deviation32);

invstd32.multiply_i(gamma32);
dx32.multiply_i(invstd32);

gamma.copyData(gamma32);
dx.copyData(dx32);
deviation.copyData(deviation32);
invstd.copyData(invstd32);
cvar.copyData(cvar32);
t_reduced.copyData(t_reduced32);
t_full.copyData(t_full32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
deriv.subtract(t_reduced_, dx);
dx.subtract_i(deviation_);

invstd_.multiply_i(gamma_);

gamma.copyData(gamma_);
deviation.copyData(deviation_);
invstd.copyData(invstd_);
cvar.copyData(cvar_);
t_reduced.copyData(t_reduced_);
t_full.copyData(t_full_);
} else {
deviation.multiply(deriv, t_full);
t_full.average(axes_to_reduce, t_reduced);
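The derivative path above is the standard batch-norm input gradient. Folding invstd into the normalized input x_hat, it reduces per channel to dx = gamma * invstd * (dy - mean(dy) - x_hat * mean(dy * x_hat)), which the plain-C++ sketch below restates (illustrative only, not nntrainer code). In the same notation, the dgamma computed in the trainable branch is sum(dy * x_hat).

#include <cstddef>
#include <vector>

// One-channel batch-norm input gradient, restating the tensor ops above:
//   dx = gamma * invstd * (dy - mean(dy) - x_hat * mean(dy * x_hat))
// where x_hat = (x - mean) * invstd is the normalized input. Plain-float
// illustration only; names are ad hoc.
std::vector<float> bn_backward_input(const std::vector<float> &dy,
                                     const std::vector<float> &x_hat,
                                     float gamma, float invstd) {
  const std::size_t n = dy.size();
  float mean_dy = 0.0f;
  float mean_dy_xhat = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    mean_dy += dy[i];
    mean_dy_xhat += dy[i] * x_hat[i];
  }
  mean_dy /= static_cast<float>(n);
  mean_dy_xhat /= static_cast<float>(n);

  std::vector<float> dx(n);
  for (std::size_t i = 0; i < n; ++i)
    dx[i] = gamma * invstd * (dy[i] - mean_dy - x_hat[i] * mean_dy_xhat);
  return dx;
}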
@@ -431,25 +343,14 @@ void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
/** dgamma is calculated in calcDerivative. dbeta is calculated here */
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

deriv32.sum(axes_to_reduce, dbeta32);
dbeta.copyData(dbeta32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
} else {

const auto &deriv_type = deriv.getDataType();
if (deriv_type == dbeta.getDataType()) {
deriv.sum(axes_to_reduce, dbeta);
} else {
Tensor dbeta_ = dbeta.clone(deriv_type);
deriv.sum(axes_to_reduce, dbeta_);
dbeta.copyData(dbeta_);
}
}
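calcGradient reduces to summing the incoming derivative over axes_to_reduce to obtain dbeta, with the same clone-and-copy-back fallback when the derivative and gradient data types differ. Per channel that is simply the sketch below (illustrative only, not nntrainer code).

#include <vector>

// dbeta for one channel is the sum of the incoming derivative values for that
// channel (the axes_to_reduce reduction performed by deriv.sum() above).
float bn_dbeta(const std::vector<float> &dy) {
  float dbeta = 0.0f;
  for (float v : dy)
    dbeta += v;
  return dbeta;
}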
