diff --git a/nntrainer/layers/conv2d_layer.cpp b/nntrainer/layers/conv2d_layer.cpp index ff44afeaf6..ada8752f7c 100644 --- a/nntrainer/layers/conv2d_layer.cpp +++ b/nntrainer/layers/conv2d_layer.cpp @@ -118,10 +118,16 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim, if (image.getDataType() == nntrainer::Tdatatype::FP32) { float val; apply_data(&val); - } else if (image.getDataType() == nntrainer::Tdatatype::FP16) { + } +#ifdef ENABLE_FP16 + else if (image.getDataType() == nntrainer::Tdatatype::FP16) { _FP16 val; apply_data(&val); } +#endif + else { + throw std::runtime_error("Not supported datatype"); + } } /** @@ -256,10 +262,16 @@ static void im2col(const Tensor &in, const TensorDim &kdim, if (out.getDataType() == nntrainer::Tdatatype::FP32) { float *out_data = out.getData(); apply_data(out_data); - } else if (out.getDataType() == nntrainer::Tdatatype::FP16) { + } +#ifdef ENABLE_FP16 + else if (out.getDataType() == nntrainer::Tdatatype::FP16) { _FP16 *out_data = out.getData<_FP16>(); apply_data(out_data); } +#endif + else { + throw std::runtime_error("Not supported datatype"); + } } } // namespace @@ -300,10 +312,11 @@ void Conv2DLayer::finalize(InitLayerContext &context) { auto &dilation = std::get>(conv_props); - TensorDim kernel_dim = - TensorDim(filter_size, in_dim.channel(), kernel_size[0], kernel_size[1], - in_dim.getTensorType()); - TensorDim bias_dim = TensorDim(1, filter_size, 1, 1, in_dim.getTensorType()); + auto in_t_type = in_dim.getTensorType(); + in_t_type.data_type = context.getWeightDataType(); + TensorDim kernel_dim = TensorDim(filter_size, in_dim.channel(), + kernel_size[0], kernel_size[1], in_t_type); + TensorDim bias_dim = TensorDim(1, filter_size, 1, 1, in_t_type); padding = std::get(conv_props) .compute(in_dim, kernel_dim, {stride[0], stride[1]}, @@ -347,19 +360,11 @@ void Conv2DLayer::finalize(InitLayerContext &context) { << "Failed to initialize: Calculated patch end is over int max"; } -void Conv2DLayer::forwarding(RunLayerContext &context, bool training) { - int status = ML_ERROR_NONE; - - unsigned int filter_size = std::get(conv_props); - auto &stride = std::get>(conv_props); - auto &dilation = - std::get>(conv_props); - - Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - - Tensor &filter_kernel = context.getWeight(wt_idx[ConvParams::weight]); - +static void forwarding_internal( + Tensor &input, Tensor &hidden, Tensor &filter_kernel, Tensor &bias_kernel, + unsigned int filter_size, const std::array &padding, + const std::array &stride, + const std::array &dilation, bool enable_bias) { /** Calculate Convolution 2D * * This is the 2D Matrix Shape [ height ] x [ width ] @@ -396,8 +401,8 @@ void Conv2DLayer::forwarding(RunLayerContext &context, bool training) { * -> [Channel ( = filter_size = output_dim.channel )] * x [output_dim.height x output_dim.width] */ - const TensorDim &in_dim = input_.getDim(); - const TensorDim &out_dim = hidden_.getDim(); + const TensorDim &in_dim = input.getDim(); + const TensorDim &out_dim = hidden.getDim(); const TensorDim &filter_dim = filter_kernel.getDim(); TensorDim filter_dim_squeezed{filter_kernel.batch(), filter_kernel.getDim().getFeatureLen()}; @@ -413,9 +418,9 @@ void Conv2DLayer::forwarding(RunLayerContext &context, bool training) { Tensor result = Tensor(calcCol2ImOutputDim(out_dim, filter_dim)); result.setZero(); for (unsigned int b = s; b < e; ++b) { - Tensor out = hidden_.getBatchSlice(b, 1); + Tensor out = hidden.getBatchSlice(b, 1); 
out.reshape({filter_size, out_dim.width() * out_dim.height()}); - Tensor in_sub = input_.getBatchSlice(b, 1); + Tensor in_sub = input.getBatchSlice(b, 1); im2col(in_sub, filter_dim, padding, stride, dilation, result); filter_kernel.dot(result, out, false, true); @@ -432,26 +437,48 @@ void Conv2DLayer::forwarding(RunLayerContext &context, bool training) { } filter_kernel.reshape(filter_dim); - if (auto &disable_bias = std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - Tensor &bias_kernel = context.getWeight(wt_idx[ConvParams::bias]); - status = hidden_.add_i(bias_kernel); + if (enable_bias) { + auto status = hidden.add_i(bias_kernel); if (status != ML_ERROR_NONE) { throw std::invalid_argument("[Conv2D] adding bias failed"); } } } -void Conv2DLayer::calcDerivative(RunLayerContext &context) { +void Conv2DLayer::forwarding(RunLayerContext &context, bool training) { + int status = ML_ERROR_NONE; + unsigned int filter_size = std::get(conv_props); auto &stride = std::get>(conv_props); auto &dilation = std::get>(conv_props); - const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX); - Tensor &input_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + Tensor &filter_kernel = context.getWeight(wt_idx[ConvParams::weight]); + Tensor &bias_kernel = context.getWeight(wt_idx[ConvParams::bias]); + + auto &disable_bias = std::get(*layer_impl_props); + bool enable_bias = !disable_bias.empty() && disable_bias.get() == false; + const auto &in_type = input_.getDataType(); + if (in_type == filter_kernel.getDataType()) { + forwarding_internal(input_, hidden_, filter_kernel, bias_kernel, + filter_size, padding, stride, dilation, enable_bias); + } else { + Tensor filter_kernel_ = filter_kernel.clone(in_type); + Tensor bias_kernel_ = bias_kernel.clone(in_type); + forwarding_internal(input_, hidden_, filter_kernel_, bias_kernel_, + filter_size, padding, stride, dilation, enable_bias); + } +} + +static void calcDerivative_internal( + const Tensor &derivative, Tensor &input_derivative, Tensor &filter_kernel, + unsigned int filter_size, const std::array &padding, + const std::array &stride, + const std::array &dilation) { TensorDim filter_dim = filter_kernel.getDim(); TensorDim filter_dim_squeezed{filter_kernel.batch(), filter_kernel.getDim().getFeatureLen()}; @@ -489,16 +516,36 @@ void Conv2DLayer::calcDerivative(RunLayerContext &context) { filter_kernel.reshape(filter_dim); } -void Conv2DLayer::calcGradient(RunLayerContext &context) { +void Conv2DLayer::calcDerivative(RunLayerContext &context) { unsigned int filter_size = std::get(conv_props); auto &stride = std::get>(conv_props); auto &dilation = std::get>(conv_props); const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX); - Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); + Tensor &input_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + Tensor &filter_kernel = context.getWeight(wt_idx[ConvParams::weight]); - Tensor &delK = context.getWeightGrad(wt_idx[ConvParams::weight]); + const auto &deriv_type = derivative.getDataType(); + if (deriv_type == filter_kernel.getDataType()) { + // filter_kernel = filter_kernel_.clone(input_.getDataType()); + calcDerivative_internal(derivative, input_derivative, filter_kernel, + filter_size, padding, stride, dilation); + + } else { + // filter_kernel = filter_kernel_; + Tensor filter_kernel_ = 
filter_kernel.clone(deriv_type); + calcDerivative_internal(derivative, input_derivative, filter_kernel_, + filter_size, padding, stride, dilation); + } +} + +void calcGradient_internal( + Tensor &input, Tensor &delK, Tensor &delBias, const Tensor &derivative, + + unsigned int filter_size, const std::array &padding, + const std::array &stride, + const std::array &dilation, bool enable_bias) { delK.setZero(); TensorDim filter_dim = delK.getDim(); @@ -514,14 +561,14 @@ void Conv2DLayer::calcGradient(RunLayerContext &context) { TensorDim out_dim_squeezed{filter_size, derivative.width() * derivative.height(), - input_.getTensorType()}; - auto workers = ParallelBatch(input_.batch()); + input.getTensorType()}; + auto workers = ParallelBatch(input.batch()); /// input -(im2col)-> column_matrix -> filter x (column_matrix) = output /// so delK = dy x column_matrix ^ T; if (workers.getNumWorkers() > 1) { TensorDim delK_ext = filter_dim_squeezed; - delK_ext.batch(input_.batch()); + delK_ext.batch(input.batch()); Tensor delK_par = Tensor(delK_ext); delK_par.setZero(); @@ -536,7 +583,7 @@ void Conv2DLayer::calcGradient(RunLayerContext &context) { Tensor delK_sub = delK_par.getBatchSlice(b, 1); deriv_sub.reshape(out_dim_squeezed); - Tensor in_sub = input_.getBatchSlice(b, 1); + Tensor in_sub = input.getBatchSlice(b, 1); /** * @todo this result can be cached from the forward iteration at the @@ -553,21 +600,20 @@ void Conv2DLayer::calcGradient(RunLayerContext &context) { workers.run(); - for (unsigned int b = 0; b < input_.batch(); ++b) { + for (unsigned int b = 0; b < input.batch(); ++b) { Tensor delK_sub = delK_par.getBatchSlice(b, 1); delK.add_i(delK_sub); } - } else { Tensor result = Tensor(calcCol2ImOutputDim(derivative.getDim(), filter_dim)); result.setZero(); - for (unsigned int b = 0; b < input_.batch(); ++b) { + for (unsigned int b = 0; b < input.batch(); ++b) { Tensor deriv_sub = derivative.getBatchSlice(b, 1); deriv_sub.reshape(out_dim_squeezed); - Tensor in_sub = input_.getBatchSlice(b, 1); + Tensor in_sub = input.getBatchSlice(b, 1); /** * @todo this result can be cached from the forward iteration at the @@ -580,13 +626,40 @@ void Conv2DLayer::calcGradient(RunLayerContext &context) { result.deallocate(); } delK.reshape(filter_dim); - if (auto &disable_bias = std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - Tensor &delBias = context.getWeightGrad(wt_idx[ConvParams::bias]); + if (enable_bias) { derivative.sum({0, 2, 3}, delBias); } } +void Conv2DLayer::calcGradient(RunLayerContext &context) { + unsigned int filter_size = std::get(conv_props); + auto &stride = std::get>(conv_props); + auto &dilation = + std::get>(conv_props); + + const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX); + Tensor &input = context.getInput(SINGLE_INOUT_IDX); + + Tensor &delK = context.getWeightGrad(wt_idx[ConvParams::weight]); + Tensor &delBias = context.getWeightGrad(wt_idx[ConvParams::bias]); + + auto &disable_bias = std::get(*layer_impl_props); + bool enable_bias = !disable_bias.empty() && disable_bias.get() == false; + + const auto &in_type = input.getDataType(); + if (in_type == delK.getDataType()) { + calcGradient_internal(input, delK, delBias, derivative, filter_size, + padding, stride, dilation, enable_bias); + } else { + Tensor delK_ = delK.clone(in_type); + Tensor delBias_ = delBias.clone(in_type); + calcGradient_internal(input, delK_, delBias_, derivative, filter_size, + padding, stride, dilation, enable_bias); + delK.copyData(delK_); + 
delBias.copyData(delBias_); + } +} + void Conv2DLayer::exportTo(Exporter &exporter, const ml::train::ExportMethods &method) const { LayerImpl::exportTo(exporter, method); diff --git a/nntrainer/layers/fc_layer.cpp b/nntrainer/layers/fc_layer.cpp index 93610e1fcc..7fbcc4c467 100644 --- a/nntrainer/layers/fc_layer.cpp +++ b/nntrainer/layers/fc_layer.cpp @@ -116,32 +116,45 @@ void FullyConnectedLayer::setProperty(const std::vector &values) { } void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) { - Tensor &weight = context.getWeight(weight_idx[FCParams::weight]); + Tensor &weight_ = context.getWeight(weight_idx[FCParams::weight]); Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - - if (weight.getDataType() == nntrainer::Tdatatype::QINT4 || - weight.getDataType() == nntrainer::Tdatatype::QINT8) { - Tdatatype dtype = input_.getDataType(); - - Tensor weight_( - {{weight.batch(), weight.channel(), weight.height(), weight.width()}, - {weight.getFormat(), dtype}}, + auto &disable_bias = std::get(*layer_impl_props); + bool enable_bias = disable_bias.empty() || disable_bias.get() == false; + + const auto &in_type = input_.getDataType(); + if (weight_.getDataType() == nntrainer::Tdatatype::QINT4 || + weight_.getDataType() == nntrainer::Tdatatype::QINT8) { + Tensor weight( + {{weight_.batch(), weight_.channel(), weight_.height(), weight_.width()}, + {weight_.getFormat(), in_type}}, true); unsigned int axis = context.getWeightObject(weight_idx[FCParams::weight]).getOutputAxis(); - weight.dequantize(weight_, axis); - input_.dot(weight_, hidden_, false, false); - } else { + weight_.dequantize(weight, axis); input_.dot(weight, hidden_, false, false); - } + if (enable_bias) { + Tensor bias = + context.getWeight(weight_idx[FCParams::bias]).clone(in_type); + hidden_.add_i(bias); + } + } else if (in_type != weight_.getDataType()) { + Tensor weight = weight_.clone(in_type); - if (auto &disable_bias = std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - Tensor &bias = context.getWeight(weight_idx[FCParams::bias]); - hidden_.add_i(bias); + input_.dot(weight, hidden_, false, false); + if (enable_bias) { + Tensor bias = + context.getWeight(weight_idx[FCParams::bias]).clone(in_type); + hidden_.add_i(bias); + } + } else { + input_.dot(weight_, hidden_, false, false); + if (enable_bias) { + Tensor &bias_ = context.getWeight(weight_idx[FCParams::bias]); + hidden_.add_i(bias_); + } } } @@ -155,6 +168,8 @@ void FullyConnectedLayer::incremental_forwarding(RunLayerContext &context, Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + auto &disable_bias = std::get(*layer_impl_props); + bool enable_bias = disable_bias.empty() || disable_bias.get() == false; TensorDim input_dim = input_.getDim(); TensorDim hidden_dim = hidden_.getDim(); @@ -172,51 +187,93 @@ void FullyConnectedLayer::incremental_forwarding(RunLayerContext &context, input_step_dim.height(to - from); hidden_step_dim.height(to - from); + const auto &in_type = input_.getDataType(); + // @todo: set reset stride as false. 
This implementation only works when batch // size is 1 Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true); Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true); - input_step.dot(weight, hidden_step, false, false); + if (in_type != weight.getDataType()) { + Tensor weight_ = weight.clone(in_type); + input_step.dot(weight_, hidden_step, false, false); + if (enable_bias) { + Tensor bias = + context.getWeight(weight_idx[FCParams::bias]).clone(in_type); + hidden_step.add_i(bias); + } + } else { + input_step.dot(weight, hidden_step, false, false); - if (auto &disable_bias = std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - Tensor &bias = context.getWeight(weight_idx[FCParams::bias]); - hidden_step.add_i(bias); + if (enable_bias) { + Tensor &bias = context.getWeight(weight_idx[FCParams::bias]); + hidden_step.add_i(bias); + } } } void FullyConnectedLayer::calcDerivative(RunLayerContext &context) { - Tensor &weight = context.getWeight(weight_idx[FCParams::weight]); + Tensor &weight_ = context.getWeight(weight_idx[FCParams::weight]); const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX); Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX); - ret_.dot_deriv_wrt_1(weight, derivative_, false, false); + const auto &deriv_type = derivative_.getDataType(); + if (deriv_type != weight_.getDataType()) { + Tensor weight = weight_.clone(deriv_type); + ret_.dot_deriv_wrt_1(weight, derivative_, false, false); + } else { + ret_.dot_deriv_wrt_1(weight_, derivative_, false, false); + } } void FullyConnectedLayer::calcGradient(RunLayerContext &context) { - Tensor &djdw = context.getWeightGrad(weight_idx[FCParams::weight]); + Tensor &djdw_ = context.getWeightGrad(weight_idx[FCParams::weight]); const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX); Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - - if (auto &disable_bias = std::get(*layer_impl_props); - disable_bias.empty() || disable_bias.get() == false) { - Tensor &djdb = context.getWeightGrad(weight_idx[FCParams::bias]); - - if (context.isGradientFirstAccess(weight_idx[FCParams::bias])) { - derivative_.sum({0, 1, 2}, djdb); - } else { - /// @todo optimize below by adding beta to Tensor::sum - Tensor t = derivative_.sum({0, 1, 2}); - djdb.add_i(t); + auto &disable_bias = std::get(*layer_impl_props); + bool enable_bias = disable_bias.empty() || disable_bias.get() == false; + bool wg_first_access = + context.isGradientFirstAccess(weight_idx[FCParams::weight]); + + const auto &in_type = input_.getDataType(); + if (in_type != djdw_.getDataType()) { + if (enable_bias) { + Tensor &djdb_ = context.getWeightGrad(weight_idx[FCParams::bias]); + Tensor djdb = djdb_.clone(in_type); + bool b_first_access = + context.isGradientFirstAccess(weight_idx[FCParams::bias]); + + if (b_first_access) { + derivative_.sum({0, 1, 2}, djdb); + } else { + /// @todo optimize below by adding beta to Tensor::sum + Tensor t = derivative_.sum({0, 1, 2}); + djdb.add_i(t); + } + djdb_.copyData(djdb); + } + Tensor djdw = djdw_.clone(in_type); + input_.dot_deriv_wrt_2(djdw, derivative_, false, false, !wg_first_access); + djdw_.copyData(djdw); + } else { + if (enable_bias) { + Tensor &djdb_ = context.getWeightGrad(weight_idx[FCParams::bias]); + bool b_first_access = + context.isGradientFirstAccess(weight_idx[FCParams::bias]); + + if (b_first_access) { + derivative_.sum({0, 1, 2}, djdb_); + } else { + /// @todo optimize below by adding beta to Tensor::sum + Tensor t 
= derivative_.sum({0, 1, 2}); + djdb_.add_i(t); + } } + input_.dot_deriv_wrt_2(djdw_, derivative_, false, false, !wg_first_access); } - input_.dot_deriv_wrt_2( - djdw, derivative_, false, false, - !context.isGradientFirstAccess(weight_idx[FCParams::weight])); } } /* namespace nntrainer */ diff --git a/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp b/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp index 11d4567709..7899fa8e03 100644 --- a/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp +++ b/nntrainer/layers/loss/cross_entropy_softmax_loss_layer.cpp @@ -93,7 +93,8 @@ void CrossEntropySoftmaxLossLayer::calcDerivative(RunLayerContext &context) { "Error when calculating loss"); } - ret_derivative.multiply_i(loss_scale); + if (loss_scale != 0.0f) + ret_derivative.multiply_i(loss_scale); } } // namespace nntrainer diff --git a/nntrainer/layers/lstm.cpp b/nntrainer/layers/lstm.cpp index b92e313287..771d76a0be 100644 --- a/nntrainer/layers/lstm.cpp +++ b/nntrainer/layers/lstm.cpp @@ -625,91 +625,37 @@ void LSTMLayer::exportTo(Exporter &exporter, exporter.saveResult(lstm_props, method, this); } -void LSTMLayer::forwarding(RunLayerContext &context, bool training) { - const bool disable_bias = - std::get(*layer_impl_props).get(); - - const unsigned int unit = std::get(lstmcore_props).get(); - const bool integrate_bias = - std::get(lstmcore_props).get(); - - const bool return_sequences = - std::get(lstm_props).get(); - const bool bidirectional = std::get(lstm_props).get(); - const float dropout_rate = std::get(lstm_props).get(); - const unsigned int max_timestep = - std::get(lstm_props).get(); - - const unsigned int bidirectional_constant = bidirectional ? 2 : 1; - bool enable_dropout = dropout_rate > epsilon && training; +static void forwarding_internal( + LSTMLayer *layer, const bool disable_bias, const unsigned int unit, + const bool integrate_bias, const bool return_sequences, + const bool bidirectional, const float dropout_rate, + const unsigned int max_timestep, const unsigned int bidirectional_constant, + bool enable_dropout, const Tensor &input, Tensor &output, + const Tensor &weight_ih, const Tensor &weight_hh, const Tensor &bias_h, + const Tensor &bias_ih, const Tensor &bias_hh, Tensor &hidden_state, + Tensor &cell_state, Tensor &ifgo, Tensor &mask, Tensor &reverse_weight_ih, + Tensor &reverse_weight_hh, Tensor &reverse_bias_h, Tensor &reverse_bias_ih, + Tensor &reverse_bias_hh, Tensor &reverse_hidden_state, + Tensor &reverse_cell_state, Tensor &reverse_ifgo, ActiFunc &acti_func, + ActiFunc &recurrent_acti_func) { - const Tensor &input = context.getInput(SINGLE_INOUT_IDX); const TensorDim input_dim = input.getDim(); const unsigned int batch_size = input_dim.batch(); const unsigned int feature_size = input_dim.width(); - Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - - const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]); - const Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]); - - TensorDim::TensorType weight_tensor_type = weight_ih.getTensorType(); - Tensor empty; - empty.setTensorType(weight_tensor_type); - - const Tensor &bias_h = !disable_bias && integrate_bias - ? context.getWeight(wt_idx[LSTMParams::bias_h]) - : empty; - const Tensor &bias_ih = !disable_bias && !integrate_bias - ? context.getWeight(wt_idx[LSTMParams::bias_ih]) - : empty; - const Tensor &bias_hh = !disable_bias && !integrate_bias - ? 
context.getWeight(wt_idx[LSTMParams::bias_hh]) - : empty; - - Tensor &hidden_state = context.getTensor(wt_idx[LSTMParams::hidden_state]); - Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]); - Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]); - Tensor &mask = enable_dropout - ? context.getTensor(wt_idx[LSTMParams::dropout_mask]) - : empty; - - forwardingBatchFirstLSTM(NUM_GATE, batch_size, feature_size, disable_bias, - unit, integrate_bias, acti_func, recurrent_acti_func, - enable_dropout, dropout_rate, max_timestep, false, - input, weight_ih, weight_hh, bias_h, bias_ih, - bias_hh, hidden_state, cell_state, ifgo, mask); + layer->forwardingBatchFirstLSTM( + layer->NUM_GATE, batch_size, feature_size, disable_bias, unit, + integrate_bias, acti_func, recurrent_acti_func, enable_dropout, + dropout_rate, max_timestep, false, input, weight_ih, weight_hh, bias_h, + bias_ih, bias_hh, hidden_state, cell_state, ifgo, mask); if (bidirectional) { - const Tensor &reverse_weight_ih = - context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]); - const Tensor &reverse_weight_hh = - context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]); - const Tensor &reverse_bias_h = - !disable_bias && integrate_bias - ? context.getWeight(wt_idx[LSTMParams::reverse_bias_h]) - : empty; - const Tensor &reverse_bias_ih = - !disable_bias && !integrate_bias - ? context.getWeight(wt_idx[LSTMParams::reverse_bias_ih]) - : empty; - const Tensor &reverse_bias_hh = - !disable_bias && !integrate_bias - ? context.getWeight(wt_idx[LSTMParams::reverse_bias_hh]) - : empty; - - Tensor &reverse_hidden_state = - context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]); - Tensor &reverse_cell_state = - context.getTensor(wt_idx[LSTMParams::reverse_cell_state]); - Tensor &reverse_ifgo = context.getTensor(wt_idx[LSTMParams::reverse_ifgo]); - - forwardingBatchFirstLSTM( - NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias, - acti_func, recurrent_acti_func, enable_dropout, dropout_rate, - max_timestep, true, input, reverse_weight_ih, reverse_weight_hh, - reverse_bias_h, reverse_bias_ih, reverse_bias_hh, reverse_hidden_state, - reverse_cell_state, reverse_ifgo, mask); + layer->forwardingBatchFirstLSTM( + layer->NUM_GATE, batch_size, feature_size, disable_bias, unit, + integrate_bias, acti_func, recurrent_acti_func, enable_dropout, + dropout_rate, max_timestep, true, input, reverse_weight_ih, + reverse_weight_hh, reverse_bias_h, reverse_bias_ih, reverse_bias_hh, + reverse_hidden_state, reverse_cell_state, reverse_ifgo, mask); } if (return_sequences && !bidirectional) { @@ -742,8 +688,6 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { std::copy(hidden_state_data, hidden_state_data + unit, output_data); if (bidirectional) { - Tensor &reverse_hidden_state = - context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]); float *reverse_hidden_state_data = reverse_hidden_state.getAddress( batch * max_timestep * unit + @@ -769,8 +713,6 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { std::copy(hidden_state_data, hidden_state_data + unit, output_data); if (bidirectional) { - Tensor &reverse_hidden_state = - context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]); _FP16 *reverse_hidden_state_data = reverse_hidden_state.getAddress<_FP16>( batch * max_timestep * unit + @@ -788,23 +730,212 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { } } -void LSTMLayer::calcDerivative(RunLayerContext &context) { +void 
LSTMLayer::forwarding(RunLayerContext &context, bool training) { + const bool disable_bias = + std::get(*layer_impl_props).get(); + const unsigned int unit = std::get(lstmcore_props).get(); + const bool integrate_bias = + std::get(lstmcore_props).get(); + const bool return_sequences = + std::get(lstm_props).get(); const bool bidirectional = std::get(lstm_props).get(); + const float dropout_rate = std::get(lstm_props).get(); + const unsigned int max_timestep = + std::get(lstm_props).get(); + const unsigned int bidirectional_constant = bidirectional ? 2 : 1; + bool enable_dropout = dropout_rate > epsilon && training; - Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + const Tensor &input = context.getInput(SINGLE_INOUT_IDX); + Tensor &output = context.getOutput(SINGLE_INOUT_IDX); const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]); - const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]); + const Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]); - calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgos); + TensorDim::TensorType weight_tensor_type = weight_ih.getTensorType(); + Tensor empty; + empty.setTensorType(weight_tensor_type); + + const Tensor &bias_h = !disable_bias && integrate_bias + ? context.getWeight(wt_idx[LSTMParams::bias_h]) + : empty; + const Tensor &bias_ih = !disable_bias && !integrate_bias + ? context.getWeight(wt_idx[LSTMParams::bias_ih]) + : empty; + const Tensor &bias_hh = !disable_bias && !integrate_bias + ? context.getWeight(wt_idx[LSTMParams::bias_hh]) + : empty; + Tensor &hidden_state = context.getTensor(wt_idx[LSTMParams::hidden_state]); + Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]); + Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]); + + Tensor &mask = enable_dropout + ? context.getTensor(wt_idx[LSTMParams::dropout_mask]) + : empty; + + Tensor &reverse_weight_ih = + bidirectional ? context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]) + : empty; + Tensor &reverse_weight_hh = + bidirectional ? context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]) + : empty; + Tensor &reverse_bias_h = + !disable_bias && integrate_bias && bidirectional + ? context.getWeight(wt_idx[LSTMParams::reverse_bias_h]) + : empty; + Tensor &reverse_bias_ih = + !disable_bias && !integrate_bias && bidirectional + ? context.getWeight(wt_idx[LSTMParams::reverse_bias_ih]) + : empty; + Tensor &reverse_bias_hh = + !disable_bias && !integrate_bias && bidirectional + ? context.getWeight(wt_idx[LSTMParams::reverse_bias_hh]) + : empty; + + Tensor &reverse_hidden_state = + bidirectional ? context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]) + : empty; + Tensor &reverse_cell_state = + bidirectional ? context.getTensor(wt_idx[LSTMParams::reverse_cell_state]) + : empty; + Tensor &reverse_ifgo = + bidirectional ? 
context.getTensor(wt_idx[LSTMParams::reverse_ifgo]) : empty; + + const auto in_type = input.getDataType(); + if (in_type != weight_ih.getDataType()) { + Tensor weight_ih_ = weight_ih.clone(in_type); + Tensor weight_hh_ = weight_hh.clone(in_type); + Tensor bias_h_ = bias_h.clone(in_type); + Tensor bias_ih_ = bias_ih.clone(in_type); + Tensor bias_hh_ = bias_hh.clone(in_type); + Tensor hidden_state_ = hidden_state.clone(in_type); + Tensor cell_state_ = cell_state.clone(in_type); + Tensor ifgo_ = ifgo.clone(in_type); + Tensor mask_ = mask.clone(in_type); + Tensor reverse_weight_ih_ = reverse_weight_ih.clone(in_type); + Tensor reverse_weight_hh_ = reverse_weight_hh.clone(in_type); + Tensor reverse_bias_h_ = reverse_bias_h.clone(in_type); + Tensor reverse_bias_ih_ = reverse_bias_ih.clone(in_type); + Tensor reverse_bias_hh_ = reverse_bias_hh.clone(in_type); + Tensor reverse_hidden_state_ = reverse_hidden_state.clone(in_type); + Tensor reverse_cell_state_ = reverse_cell_state.clone(in_type); + Tensor reverse_ifgo_ = reverse_ifgo.clone(in_type); + + forwarding_internal( + this, disable_bias, unit, integrate_bias, return_sequences, bidirectional, + dropout_rate, max_timestep, bidirectional_constant, enable_dropout, input, + output, weight_ih_, weight_hh_, bias_h_, bias_ih_, bias_hh_, + hidden_state_, cell_state_, ifgo_, mask_, reverse_weight_ih_, + reverse_weight_hh_, reverse_bias_h_, reverse_bias_ih_, reverse_bias_hh_, + reverse_hidden_state_, reverse_cell_state_, reverse_ifgo_, acti_func, + recurrent_acti_func); + + hidden_state.copyData(hidden_state_); + cell_state.copyData(cell_state_); + ifgo.copyData(ifgo_); + mask.copyData(mask_); + reverse_weight_ih.copyData(reverse_weight_ih_); + reverse_weight_hh.copyData(reverse_weight_hh_); + reverse_bias_h.copyData(reverse_bias_h_); + reverse_bias_ih.copyData(reverse_bias_ih_); + reverse_bias_hh.copyData(reverse_bias_hh_); + reverse_hidden_state.copyData(reverse_hidden_state_); + reverse_cell_state.copyData(reverse_cell_state_); + reverse_ifgo.copyData(reverse_ifgo_); + } else { + forwarding_internal( + this, disable_bias, unit, integrate_bias, return_sequences, bidirectional, + dropout_rate, max_timestep, bidirectional_constant, enable_dropout, input, + output, weight_ih, weight_hh, bias_h, bias_ih, bias_hh, hidden_state, + cell_state, ifgo, mask, reverse_weight_ih, reverse_weight_hh, + reverse_bias_h, reverse_bias_ih, reverse_bias_hh, reverse_hidden_state, + reverse_cell_state, reverse_ifgo, acti_func, recurrent_acti_func); + } +} + +void calcDerivativeInternal(LSTMLayer *layer, bool bidirectional, + Tensor &outgoing_derivative, Tensor &weight_ih, + Tensor &d_ifgos, Tensor &reverse_weight_ih, + Tensor &reverse_d_ifgos) { + layer->calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgos); if (bidirectional) { - const Tensor &reverse_weight_ih = - context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]); - const Tensor &reverse_d_ifgos = - context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]); + layer->calcDerivativeLSTM(outgoing_derivative, reverse_weight_ih, + reverse_d_ifgos, 1.0f); + } +} + +void LSTMLayer::calcDerivative(RunLayerContext &context) { + const bool bidirectional = std::get(lstm_props).get(); + + Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]); + Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]); - calcDerivativeLSTM(outgoing_derivative, reverse_weight_ih, reverse_d_ifgos, - 1.0f); + Tensor empty; + 
empty.setTensorType(outgoing_derivative.getTensorType()); + + Tensor &reverse_weight_ih = + bidirectional ? context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]) + : empty; + Tensor &reverse_d_ifgos = + bidirectional ? context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]) + : empty; + + const auto out_type = outgoing_derivative.getDataType(); + if (out_type != weight_ih.getDataType()) { + Tensor weight_ih_ = weight_ih.clone(out_type); + Tensor d_ifgos_ = d_ifgos.clone(out_type); + Tensor reverse_weight_ih_ = reverse_weight_ih.clone(out_type); + Tensor reverse_d_ifgos_ = reverse_d_ifgos.clone(out_type); + + calcDerivativeInternal(this, bidirectional, outgoing_derivative, weight_ih_, + d_ifgos_, reverse_weight_ih_, reverse_d_ifgos_); + + weight_ih.copyData(weight_ih_); + d_ifgos.copyData(d_ifgos_); + reverse_weight_ih.copyData(reverse_weight_ih_); + reverse_d_ifgos.copyData(reverse_d_ifgos_); + } else { + calcDerivativeInternal(this, bidirectional, outgoing_derivative, weight_ih, + d_ifgos, reverse_weight_ih, reverse_d_ifgos); + } +} + +static void calcGradientInternal( + LSTMLayer *layer, const unsigned int batch_size, + const unsigned int feature_size, const bool disable_bias, + const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func, + ActiFunc &recurrent_acti_func, const bool return_sequences, + const bool bidirectional, const bool enable_dropout, const float dropout_rate, + const unsigned int max_timestep, const bool reverse, const Tensor &input, + const Tensor &incoming_derivative, Tensor &d_weight_ih, + const Tensor &weight_hh, Tensor &d_weight_hh, Tensor &d_bias_h, + Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &hidden_state, + Tensor &d_hidden_state, const Tensor &cell_state, Tensor &d_cell_state, + const Tensor &ifgo, Tensor &d_ifgo, const Tensor &mask, + Tensor &reverse_d_weight_ih, const Tensor &reverse_weight_hh, + Tensor &reverse_d_weight_hh, Tensor &reverse_d_bias_h, + Tensor &reverse_d_bias_ih, Tensor &reverse_d_bias_hh, + const Tensor &reverse_hidden_state, Tensor &reverse_d_hidden_state, + const Tensor &reverse_cell_state, Tensor &reverse_d_cell_state, + const Tensor &reverse_ifgo, Tensor &reverse_d_ifgo) { + layer->calcGradientBatchFirstLSTM( + layer->NUM_GATE, batch_size, feature_size, disable_bias, unit, + integrate_bias, acti_func, recurrent_acti_func, return_sequences, + bidirectional, enable_dropout, dropout_rate, max_timestep, false, input, + incoming_derivative, d_weight_ih, weight_hh, d_weight_hh, d_bias_h, + d_bias_ih, d_bias_hh, hidden_state, d_hidden_state, cell_state, + d_cell_state, ifgo, d_ifgo, mask); + + if (bidirectional) { + layer->calcGradientBatchFirstLSTM( + layer->NUM_GATE, batch_size, feature_size, disable_bias, unit, + integrate_bias, acti_func, recurrent_acti_func, return_sequences, + bidirectional, enable_dropout, dropout_rate, max_timestep, true, input, + incoming_derivative, reverse_d_weight_ih, reverse_weight_hh, + reverse_d_weight_hh, reverse_d_bias_h, reverse_d_bias_ih, + reverse_d_bias_hh, reverse_hidden_state, reverse_d_hidden_state, + reverse_cell_state, reverse_d_cell_state, reverse_ifgo, reverse_d_ifgo, + mask); } } @@ -864,57 +995,113 @@ void LSTMLayer::calcGradient(RunLayerContext &context) { ? 
context.getTensor(wt_idx[LSTMParams::dropout_mask]) : empty; - calcGradientBatchFirstLSTM( - NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias, - acti_func, recurrent_acti_func, return_sequences, bidirectional, - enable_dropout, dropout_rate, max_timestep, false, input, - incoming_derivative, d_weight_ih, weight_hh, d_weight_hh, d_bias_h, - d_bias_ih, d_bias_hh, hidden_state, d_hidden_state, cell_state, - d_cell_state, ifgo, d_ifgo, mask); - - if (bidirectional) { - Tensor &reverse_d_weight_ih = - context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_ih]); - const Tensor &reverse_weight_hh = - context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]); - Tensor &reverse_d_weight_hh = - context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_hh]); - Tensor &reverse_d_bias_h = - !disable_bias && integrate_bias - ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_h]) - : empty; - Tensor &reverse_d_bias_ih = - !disable_bias && !integrate_bias - ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_ih]) - : empty; - Tensor &reverse_d_bias_hh = - !disable_bias && !integrate_bias - ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_hh]) - : empty; - - const Tensor &reverse_hidden_state = - context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]); - Tensor &reverse_d_hidden_state = - context.getTensorGrad(wt_idx[LSTMParams::reverse_hidden_state]); - const Tensor &reverse_cell_state = - context.getTensor(wt_idx[LSTMParams::reverse_cell_state]); - Tensor &reverse_d_cell_state = - context.getTensorGrad(wt_idx[LSTMParams::reverse_cell_state]); - - const Tensor &reverse_ifgo = - context.getTensor(wt_idx[LSTMParams::reverse_ifgo]); - Tensor &reverse_d_ifgo = - context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]); - - calcGradientBatchFirstLSTM( - NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias, + Tensor &reverse_d_weight_ih = + bidirectional ? context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_ih]) + : empty; + const Tensor &reverse_weight_hh = + bidirectional ? context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]) + : empty; + Tensor &reverse_d_weight_hh = + bidirectional ? context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_hh]) + : empty; + Tensor &reverse_d_bias_h = + !disable_bias && integrate_bias && bidirectional + ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_h]) + : empty; + Tensor &reverse_d_bias_ih = + !disable_bias && !integrate_bias && bidirectional + ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_ih]) + : empty; + Tensor &reverse_d_bias_hh = + !disable_bias && !integrate_bias && bidirectional + ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_hh]) + : empty; + const Tensor &reverse_hidden_state = + bidirectional ? context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]) + : empty; + Tensor &reverse_d_hidden_state = + bidirectional + ? context.getTensorGrad(wt_idx[LSTMParams::reverse_hidden_state]) + : empty; + const Tensor &reverse_cell_state = + bidirectional ? context.getTensor(wt_idx[LSTMParams::reverse_cell_state]) + : empty; + Tensor &reverse_d_cell_state = + bidirectional + ? context.getTensorGrad(wt_idx[LSTMParams::reverse_cell_state]) + : empty; + + const Tensor &reverse_ifgo = + bidirectional ? context.getTensor(wt_idx[LSTMParams::reverse_ifgo]) : empty; + Tensor &reverse_d_ifgo = + bidirectional ? 
context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]) + : empty; + + const auto in_type = input.getDataType(); + if (in_type != d_weight_ih.getDataType()) { + Tensor weight_hh_ = weight_hh.clone(in_type); + Tensor d_weight_ih_ = d_weight_ih.clone(in_type); + Tensor d_weight_hh_ = d_weight_hh.clone(in_type); + Tensor d_bias_h_ = d_bias_h.clone(in_type); + Tensor d_bias_ih_ = d_bias_ih.clone(in_type); + Tensor d_bias_hh_ = d_bias_hh.clone(in_type); + Tensor hidden_state_ = hidden_state.clone(in_type); + Tensor d_hidden_state_ = d_hidden_state.clone(in_type); + Tensor cell_state_ = cell_state.clone(in_type); + Tensor d_cell_state_ = d_cell_state.clone(in_type); + Tensor ifgo_ = ifgo.clone(in_type); + Tensor d_ifgo_ = d_ifgo.clone(in_type); + Tensor mask_ = mask.clone(in_type); + Tensor reverse_d_weight_ih_ = reverse_d_weight_ih.clone(in_type); + Tensor reverse_weight_hh_ = reverse_weight_hh.clone(in_type); + Tensor reverse_d_bias_h_ = reverse_d_bias_h.clone(in_type); + Tensor reverse_d_bias_ih_ = reverse_d_bias_ih.clone(in_type); + Tensor reverse_d_bias_hh_ = reverse_d_bias_hh.clone(in_type); + Tensor reverse_hidden_state_ = reverse_hidden_state.clone(in_type); + Tensor reverse_d_hidden_state_ = reverse_d_hidden_state.clone(in_type); + Tensor reverse_cell_state_ = reverse_cell_state.clone(in_type); + Tensor reverse_d_cell_state_ = reverse_d_cell_state.clone(in_type); + Tensor reverse_ifgo_ = reverse_ifgo.clone(in_type); + Tensor reverse_d_ifgo_ = reverse_d_ifgo.clone(in_type); + + calcGradientInternal( + this, batch_size, feature_size, disable_bias, unit, integrate_bias, acti_func, recurrent_acti_func, return_sequences, bidirectional, - enable_dropout, dropout_rate, max_timestep, true, input, - incoming_derivative, reverse_d_weight_ih, reverse_weight_hh, + enable_dropout, dropout_rate, max_timestep, false, input, + incoming_derivative, d_weight_ih_, weight_hh_, d_weight_hh_, d_bias_h_, + d_bias_ih_, d_bias_hh_, hidden_state_, d_hidden_state_, cell_state_, + d_cell_state_, ifgo_, d_ifgo_, mask_, reverse_d_weight_ih_, + reverse_weight_hh_, reverse_weight_hh_, reverse_d_bias_h_, + reverse_d_bias_ih_, reverse_d_bias_hh_, reverse_hidden_state_, + reverse_d_hidden_state_, reverse_cell_state_, reverse_d_cell_state_, + reverse_ifgo_, reverse_d_ifgo_); + + d_weight_ih.copyData(d_weight_ih_); + d_weight_hh.copyData(d_weight_hh_); + d_hidden_state.copyData(d_hidden_state_); + d_cell_state.copyData(d_cell_state_); + d_bias_h.copyData(d_bias_h_); + d_bias_ih.copyData(d_bias_ih_); + d_bias_hh.copyData(d_bias_hh_); + d_ifgo.copyData(d_ifgo_); + reverse_d_weight_ih.copyData(reverse_d_weight_ih_); + reverse_d_bias_h.copyData(reverse_d_bias_h_); + reverse_d_bias_ih.copyData(reverse_d_bias_ih_); + reverse_d_bias_hh.copyData(reverse_d_bias_hh_); + reverse_d_hidden_state.copyData(reverse_d_hidden_state_); + reverse_d_cell_state.copyData(reverse_d_cell_state_); + reverse_d_ifgo.copyData(reverse_d_ifgo_); + } else { + calcGradientInternal( + this, batch_size, feature_size, disable_bias, unit, integrate_bias, + acti_func, recurrent_acti_func, return_sequences, bidirectional, + enable_dropout, dropout_rate, max_timestep, false, input, + incoming_derivative, d_weight_ih, weight_hh, d_weight_hh, d_bias_h, + d_bias_ih, d_bias_hh, hidden_state, d_hidden_state, cell_state, + d_cell_state, ifgo, d_ifgo, mask, reverse_d_weight_ih, reverse_weight_hh, reverse_d_weight_hh, reverse_d_bias_h, reverse_d_bias_ih, reverse_d_bias_hh, reverse_hidden_state, reverse_d_hidden_state, - reverse_cell_state, reverse_d_cell_state, 
reverse_ifgo, reverse_d_ifgo, - mask); + reverse_cell_state, reverse_d_cell_state, reverse_ifgo, reverse_d_ifgo); } } diff --git a/nntrainer/layers/lstm.h b/nntrainer/layers/lstm.h index f35fdf8815..a9b2cac7d7 100644 --- a/nntrainer/layers/lstm.h +++ b/nntrainer/layers/lstm.h @@ -99,7 +99,6 @@ class LSTMLayer : public LSTMCore { inline static const std::string type = "lstm"; -private: static constexpr unsigned int NUM_GATE = 4; /** common properties like Unit, IntegrateBias, HiddenStateActivation and diff --git a/nntrainer/layers/pooling2d_layer.cpp b/nntrainer/layers/pooling2d_layer.cpp index 036a933c42..1638da9efe 100644 --- a/nntrainer/layers/pooling2d_layer.cpp +++ b/nntrainer/layers/pooling2d_layer.cpp @@ -26,6 +26,9 @@ namespace nntrainer { static constexpr size_t SINGLE_INOUT_IDX = 0; +/** + * @brief help function for Pooling handler + */ template struct PoolFunc { typedef std::function Type; }; @@ -185,7 +188,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { unsigned int out_map_size = deriv.height() * deriv.width(); unsigned int in_map_size = height * width; - auto apply_max = [&](T *result_data) { + auto apply_max = [&](T * result_data) { const int *iter = pool_helper.getData(); const T *deriv_data = deriv.getData(); for (unsigned int b = 0; b < batch; ++b) { @@ -204,7 +207,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { } }; - auto apply_average = [&](T *result_data) { + auto apply_average = [&](T * result_data) { int height_stride_end = height - p_height + pt; int width_stride_end = width - p_width + pl; const int *iter = pool_helper.getData(); @@ -236,7 +239,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { } }; - auto apply_global_max = [&](T *result_data) { + auto apply_global_max = [&](T * result_data) { const T *deriv_data = deriv.getData(); for (unsigned int b = 0; b < batch; b++) { for (unsigned int c = 0; c < channel; c++) { @@ -258,21 +261,33 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { case props::PoolingTypeInfo::Enum::max: if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP32) apply_max(result.getData()); +#ifdef ENABLE_FP16 else if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP16) apply_max(result.getData<_FP16>()); +#endif + else + throw std::runtime_error("Not supported datatype"); break; case props::PoolingTypeInfo::Enum::global_average: case props::PoolingTypeInfo::Enum::average: if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP32) apply_average(result.getData()); +#ifdef ENABLE_FP16 else if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP16) apply_average(result.getData<_FP16>()); +#endif + else + throw std::runtime_error("Not supported datatype"); break; case props::PoolingTypeInfo::Enum::global_max: if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP32) apply_global_max(result.getData()); +#ifdef ENABLE_FP16 else if (in_dim.getDataType() == ml::train::TensorDim::DataType::FP16) apply_global_max(result.getData<_FP16>()); +#endif + else + throw std::runtime_error("Not supported datatype"); break; default: throw std::runtime_error("Error: Unknown Pooling Type"); @@ -320,7 +335,9 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, * @return result value of pooling */ PoolFunc::Type pool_fn_fp32; +#ifdef ENABLE_FP16 PoolFunc<_FP16>::Type pool_fn_fp16; +#endif unsigned int max_idx_count = 0; @@ -355,9 +372,8 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, return max_val; }; 
- auto pool_fn_global_max = [&, this](const T *in_data, - int channel_idx, int start_h, - int start_w) { + auto pool_fn_global_max = [&, this ]( + const T *in_data, int channel_idx, int start_h, int start_w) { int end_h = start_h + patch_height; int end_w = start_w + patch_width; @@ -412,16 +428,22 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, switch (pooling_type) { case props::PoolingTypeInfo::Enum::max: pool_fn_fp32 = pool_fn_max; +#ifdef ENABLE_FP16 pool_fn_fp16 = pool_fn_max; +#endif break; case props::PoolingTypeInfo::Enum::global_max: pool_fn_fp32 = pool_fn_global_max; +#ifdef ENABLE_FP16 pool_fn_fp16 = pool_fn_global_max; +#endif break; case props::PoolingTypeInfo::Enum::global_average: case props::PoolingTypeInfo::Enum::average: pool_fn_fp32 = pool_fn_average; +#ifdef ENABLE_FP16 pool_fn_fp16 = pool_fn_average; +#endif break; case props::PoolingTypeInfo::Enum::unknown: default: @@ -447,7 +469,9 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, } } } - } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { + } +#ifdef ENABLE_FP16 + else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { const _FP16 *in_data = in.getData<_FP16>(); _FP16 *out_data = output.getData<_FP16>(); @@ -466,6 +490,10 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, } } } +#endif + else { + throw std::runtime_error("Not supported datatype"); + } } void Pooling2DLayer::setBatch(RunLayerContext &context, unsigned int batch) {
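Taken together, these changes apply two recurring patterns across the conv2d, fully-connected, LSTM, pooling, and loss layers: FP16 code paths are now compiled only under `ENABLE_FP16` (with a runtime error for any other data type), and the per-layer math is factored into `*_internal` helpers so that, when weights or gradients are stored in a different data type than the activations, they are cloned to the activation type with `clone(in_type)`, used for the computation, and written back with `copyData`. The standalone sketch below illustrates both patterns under stated assumptions; `MiniTensor`, `DataType`, `dispatch_by_dtype`, `run_kernel`, and `calc_gradient_mixed` are hypothetical stand-ins for illustration only, not nntrainer's actual `Tensor` API.

```cpp
// Illustrative sketch only: hypothetical MiniTensor/DataType/run_kernel,
// not nntrainer code. Shows (1) dtype dispatch guarded by ENABLE_FP16 and
// (2) clone-to-activation-type / copy-back for mixed-precision gradients.
#include <cstddef>
#include <stdexcept>
#include <vector>

enum class DataType { FP32, FP16 };

struct MiniTensor {
  DataType dtype;
  std::vector<float> data; // storage simplified to float for the sketch

  // Return a copy tagged with the target data type.
  MiniTensor clone(DataType target) const { return MiniTensor{target, data}; }
  // Copy values back without changing this tensor's own data type.
  void copyData(const MiniTensor &src) { data = src.data; }
};

// Pattern 1: compile the FP16 branch only under ENABLE_FP16 and reject any
// other data type at runtime, as col2im/im2col and pooling2d now do.
void dispatch_by_dtype(MiniTensor &t) {
  if (t.dtype == DataType::FP32) {
    /* FP32 kernel */
  }
#ifdef ENABLE_FP16
  else if (t.dtype == DataType::FP16) {
    /* FP16 kernel */
  }
#endif
  else {
    throw std::runtime_error("Not supported datatype");
  }
}

// A toy kernel that requires both operands in the same data type.
void run_kernel(const MiniTensor &in, const MiniTensor &w, MiniTensor &out) {
  if (in.dtype != w.dtype)
    throw std::invalid_argument("kernel expects matching data types");
  out.dtype = in.dtype;
  out.data.assign(in.data.size(), 0.0f);
  for (std::size_t i = 0; i < in.data.size(); ++i)
    out.data[i] = in.data[i] * w.data[i % w.data.size()];
}

// Pattern 2: when a gradient tensor is stored in a different type than the
// activations, clone it to the activation type, compute on the copy, and
// copy the result back into the original-precision tensor.
void calc_gradient_mixed(const MiniTensor &input, MiniTensor &grad,
                         MiniTensor &scratch) {
  if (input.dtype == grad.dtype) {
    run_kernel(input, grad, scratch);
  } else {
    MiniTensor grad_ = grad.clone(input.dtype); // temporary working copy
    run_kernel(input, grad_, scratch);
    grad.copyData(grad_);                       // write results back
  }
}

int main() {
  MiniTensor in{DataType::FP32, {1.f, 2.f, 3.f, 4.f}};
  MiniTensor grad{DataType::FP16, {0.5f, 0.25f}}; // stored in "FP16"
  MiniTensor out{DataType::FP32, {}};
  dispatch_by_dtype(in);
  calc_gradient_mixed(in, grad, out);
  return out.data.size() == 4 ? 0 : 1;
}
```

The trade-off of the clone/copy-back path is an extra allocation and copy whenever activation and weight types differ, but it keeps every compute kernel (im2col, dot, the LSTM gate math) operating on a single data type instead of threading type conversions through each of them.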