Modify bn layer for mixed precision
The bn layer is modified to support mixed data types.

Signed-off-by: Jiho Chu <[email protected]>
jihochu committed Feb 21, 2024
1 parent 24cad9d commit 8bef976
Showing 1 changed file with 77 additions and 176 deletions: nntrainer/layers/bn_layer.cpp
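The change replaces the FP16-specific path, which materialized FP32 copies of every tensor under #ifdef ENABLE_FP16, with a generic one: whenever the activation data type differs from the parameter data type, the parameters are cloned into the activation type, the batch-norm math runs there, and the results are copied back into the master copies. Below is a minimal standalone sketch of that pattern in plain standard C++, with float standing in for the FP32 master statistics and double for the activation type; none of the names in it come from nntrainer.

// Sketch of the mixed-precision pattern: keep master parameters in one type,
// "clone" them into the activation type, compute there, then copy back.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

template <typename MasterT, typename ActT>
void bn_forward_mixed(const std::vector<ActT> &input, MasterT &mu, MasterT &var,
                      std::vector<ActT> &out, MasterT momentum, MasterT epsilon) {
  // 1. Clone the master running statistics into the activation data type
  //    (the role Tensor::clone(in_type) plays in the diff).
  ActT mu_a = static_cast<ActT>(mu);
  ActT var_a = static_cast<ActT>(var);

  // 2. Compute mean/variance and normalize in the activation type
  //    (single channel, gamma/beta omitted, to keep the sketch short).
  ActT mean = 0;
  for (ActT x : input)
    mean += x;
  mean /= static_cast<ActT>(input.size());

  ActT cvar = 0;
  for (ActT x : input)
    cvar += (x - mean) * (x - mean);
  cvar /= static_cast<ActT>(input.size());

  const ActT m = static_cast<ActT>(momentum);
  mu_a = mu_a * m + mean * (1 - m);
  var_a = var_a * m + cvar * (1 - m);

  const ActT invstd = 1 / std::sqrt(cvar + static_cast<ActT>(epsilon));
  out.resize(input.size());
  for (std::size_t i = 0; i < input.size(); ++i)
    out[i] = (input[i] - mean) * invstd;

  // 3. Copy the updated statistics back into the master copies
  //    (the role Tensor::copyData() plays in the diff).
  mu = static_cast<MasterT>(mu_a);
  var = static_cast<MasterT>(var_a);
}

int main() {
  std::vector<double> input{1.0, 2.0, 3.0, 4.0}; // activations
  float mu = 0.0f, var = 1.0f;                   // master running stats
  std::vector<double> out;
  bn_forward_mixed(input, mu, var, out, 0.9f, 1e-5f);
  std::cout << "running mean " << mu << ", running var " << var << "\n";
}

In the actual diff, Tensor::clone(in_type) and Tensor::copyData() play these two roles inside forwarding(), calcDerivative(), and calcGradient().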
@@ -174,98 +174,53 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
/** use hidden_ as temporary tensor before setting the result in hidden */
Tensor t_full = hidden_;
Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim mu_dim = mu.getDim();
mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor mu32(mu_dim, true);
mu32.copyData(mu);

TensorDim var_dim = var.getDim();
var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor var32(var_dim, true);
var32.copyData(var);

TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim beta_dim = beta.getDim();
beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor beta32(beta_dim, true);
beta32.copyData(beta);

TensorDim input_dim = input_.getDim();
input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor input_32(input_dim, true);
input_32.copyData(input_);

TensorDim hidden_dim = hidden_.getDim();
hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor hidden_32(hidden_dim, true);
hidden_32.copyData(hidden_);
Tensor t_full32 = hidden_32;

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim dim_invstd = invstd.getDim();
dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(dim_invstd, true);
invstd32.copyData(invstd);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

const auto &in_type = input_.getDataType();
if (in_type != mu.getDataType()) {
// It calculates with the activation data type
Tensor mu_ = mu.clone(in_type);
Tensor var_ = var.clone(in_type);
Tensor gamma_ = gamma.clone(in_type);
Tensor beta_ = beta.clone(in_type);
Tensor deviation_ = deviation.clone(in_type);
Tensor invstd_ = invstd.clone(in_type);
Tensor t_reduced_ = t_reduced.clone(in_type);
Tensor cvar_ = cvar.clone(in_type);

if (training) {
input_32.average(axes_to_reduce, t_reduced32);
input_32.subtract(t_reduced32, deviation32);
input_.average(axes_to_reduce, t_reduced_);
input_.subtract(t_reduced_, deviation_);

mu32.multiply_i(momentum);
mu32.add_i(t_reduced32, 1 - momentum);
mu_.multiply_i(momentum);
mu_.add_i(t_reduced_, 1 - momentum);

deviation32.pow(2.0f, t_full32);
t_full32.average(axes_to_reduce, cvar32);
deviation_.pow(2.0f, t_full);
t_full.average(axes_to_reduce, cvar_);

var32.multiply_i(momentum);
var32.add_i(cvar32, 1 - momentum);
var_.multiply_i(momentum);
var_.add_i(cvar_, 1 - momentum);

cvar32.add_i(epsilon);
cvar32.pow(-0.5f, invstd32);
cvar_.add_i(epsilon);
cvar_.pow(-0.5f, invstd_);
} else {
input_32.subtract(mu32, deviation32);
input_.subtract(mu_, deviation_);
/** @todo do below 2 lines only for first iteration */
var32.add(epsilon, invstd32);
invstd32.pow_i(-0.5f);
var_.add(epsilon, invstd_);
invstd_.pow_i(-0.5f);
}

deviation32.multiply(invstd32, hidden_32);
hidden_32.multiply_i(gamma32);
hidden_32.add_i(beta32);

mu.copyData(mu32);
var.copyData(var32);
gamma.copyData(gamma32);
beta.copyData(beta32);
input_.copyData(input_32);
hidden_.copyData(hidden_32);
deviation.copyData(deviation32);
invstd.copyData(invstd32);
t_reduced.copyData(t_reduced32);
cvar.copyData(cvar32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
deviation_.multiply(invstd_, hidden_);
hidden_.multiply_i(gamma_);
hidden_.add_i(beta_);

mu.copyData(mu_);
var.copyData(var_);
gamma.copyData(gamma_);
beta.copyData(beta_);
deviation.copyData(deviation_);
invstd.copyData(invstd_);
t_reduced.copyData(t_reduced_);
cvar.copyData(cvar_);
} else {
if (training) {
input_.average(axes_to_reduce, t_reduced);
@@ -306,96 +261,53 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim gamma_dim = gamma.getDim();
gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor gamma32(gamma_dim, true);
gamma32.copyData(gamma);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

TensorDim dx_dim = dx.getDim();
dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dx32(dx_dim, true);
dx32.copyData(dx);

TensorDim deviation_dim = deviation.getDim();
deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deviation32(deviation_dim, true);
deviation32.copyData(deviation);

TensorDim invstd_dim = invstd.getDim();
invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor invstd32(invstd_dim, true);
invstd32.copyData(invstd);

TensorDim cvar_dim = cvar.getDim();
cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor cvar32(cvar_dim, true);
cvar32.copyData(cvar);

TensorDim t_reduced_dim = t_reduced.getDim();
t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_reduced32(t_reduced_dim, true);
t_reduced32.copyData(t_reduced);

TensorDim t_full_dim = t_full.getDim();
t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor t_full32(t_full_dim, true);
t_full32.copyData(t_full);

deviation32.multiply(deriv32, t_full32);
t_full32.average(axes_to_reduce, t_reduced32);
t_reduced32.divide_i(cvar32);
deviation32.multiply_i(t_reduced32);

const auto &deriv_type = deriv.getDataType();

if (deriv_type != gamma.getDataType()) {
Tensor gamma_ = gamma.clone(deriv_type);
Tensor deviation_ = deviation.clone(deriv_type);
Tensor invstd_ = invstd.clone(deriv_type);
Tensor cvar_ = cvar.clone(deriv_type);
Tensor t_reduced_ = t_reduced.clone(deriv_type);
Tensor t_full_ = t_full.clone(deriv_type);

deviation_.multiply(deriv, t_full_);
t_full_.average(axes_to_reduce, t_reduced_);
t_reduced_.divide_i(cvar_);
deviation_.multiply_i(t_reduced_);

if (context.getTrainable()) {
/**
* This calculates dgamma tensor.
*/
Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
TensorDim dgamma_dim = dgamma.getDim();
dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dgamma32(dgamma_dim, true);
dgamma32.copyData(dgamma);

t_full32.multiply_i(invstd32);
t_full32.sum(axes_to_reduce, dgamma32);
dgamma.copyData(dgamma32);
Tensor dgamma_ = dgamma.clone(deriv_type);
t_full_.multiply_i(invstd_);
t_full_.sum(axes_to_reduce, dgamma_);
dgamma.copyData(dgamma_);

/**
* This implementation depends on dbeta being calculated beforehand.
*/
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);
dbeta32.divide(divider, t_reduced32);
Tensor dbeta_ = dbeta.clone(deriv_type);
dbeta_.divide(divider, t_reduced_);
} else {
deriv32.average(axes_to_reduce, t_reduced32);
deriv.average(axes_to_reduce, t_reduced_);
}

deriv32.subtract(t_reduced32, dx32);
dx32.subtract_i(deviation32);

invstd32.multiply_i(gamma32);
dx32.multiply_i(invstd32);

gamma.copyData(gamma32);
dx.copyData(dx32);
deviation.copyData(deviation32);
invstd.copyData(invstd32);
cvar.copyData(cvar32);
t_reduced.copyData(t_reduced32);
t_full.copyData(t_full32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
deriv.subtract(t_reduced_, dx);
dx.subtract_i(deviation_);

invstd_.multiply_i(gamma_);

gamma.copyData(gamma_);
deviation.copyData(deviation_);
invstd.copyData(invstd_);
cvar.copyData(cvar_);
t_reduced.copyData(t_reduced_);
t_full.copyData(t_full_);
} else {
deviation.multiply(deriv, t_full);
t_full.average(axes_to_reduce, t_reduced);
@@ -431,25 +343,14 @@ void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
/** dgamma is calculated in calcDerivative. dbeta is calculated here */
Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim dbeta_dim = dbeta.getDim();
dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor dbeta32(dbeta_dim, true);
dbeta32.copyData(dbeta);

TensorDim deriv_dim = deriv.getDim();
deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor deriv32(deriv_dim, true);
deriv32.copyData(deriv);

deriv32.sum(axes_to_reduce, dbeta32);
dbeta.copyData(dbeta32);
#else
throw std::runtime_error("enable-fp16 is not enabled");
#endif
} else {

const auto &deriv_type = deriv.getDataType();
if (deriv_type == dbeta.getDataType()) {
deriv.sum(axes_to_reduce, dbeta);
} else {
Tensor dbeta_ = dbeta.clone(deriv_type);
deriv.sum(axes_to_reduce, dbeta_);
dbeta.copyData(dbeta_);
}
}
