diff --git a/Applications/KNN/jni/meson.build b/Applications/KNN/jni/meson.build index bc50dc0214..58ca099d75 100644 --- a/Applications/KNN/jni/meson.build +++ b/Applications/KNN/jni/meson.build @@ -15,4 +15,4 @@ e = executable('knn_sample', install_dir: application_install_dir ) -test('app_knn', e, args: [nntr_app_resdir / 'KNN']) +test('app_knn', e, args: [nntr_app_resdir / 'KNN/']) diff --git a/api/ccapi/include/model.h b/api/ccapi/include/model.h index e4d3a1bfe1..ef6303e6da 100644 --- a/api/ccapi/include/model.h +++ b/api/ccapi/include/model.h @@ -136,7 +136,7 @@ class Model { * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - virtual int compile() = 0; + virtual int compile(ExecutionMode exec_mode_ = ExecutionMode::TRAIN) = 0; /** * @brief Initialize Network. This should be called after setting the diff --git a/debian/nntrainer-dev.install b/debian/nntrainer-dev.install index 4fd55b3774..11b41f990b 100644 --- a/debian/nntrainer-dev.install +++ b/debian/nntrainer-dev.install @@ -16,6 +16,7 @@ /usr/include/nntrainer/blas_interface.h /usr/include/nntrainer/var_grad.h /usr/include/nntrainer/weight.h +/usr/include/nntrainer/blas_avx.h # todo: update dataset headers /usr/include/nntrainer/databuffer.h /usr/include/nntrainer/databuffer_factory.h diff --git a/meson.build b/meson.build index d4aea330a4..fef811ab97 100644 --- a/meson.build +++ b/meson.build @@ -64,9 +64,21 @@ warning_c_flags = [ '-Wno-error=varargs' ] +arch = host_machine.cpu_family() + +if get_option('enable-avx') and arch == 'x86_64' + extra_defines += '-DUSE_AVX=1' + if get_option('platform') == 'tizen' + add_project_arguments(['-mavx2'], language: ['c','cpp']) + else + add_project_arguments(['-march=native'], language: ['c','cpp']) + endif + message('-march=native added for AVX hardware acceleration.') +elif get_option('enable-avx') + warning('AVX enabled for non x86_64 build target. The enable-avx option is ignored.') +endif if get_option('enable-fp16') - arch = host_machine.cpu_family() if get_option('platform') == 'android' add_project_arguments('-mfp16-format=ieee', language: ['c', 'cpp']) extra_defines += '-DENABLE_FP16=1' @@ -105,11 +117,6 @@ if get_option('enable-fp16') if cc.version().version_compare('>=12.1.0') message ('Float16 for x86_64 enabled. Modern gcc-x64 generally supports float16 with _Float16.') extra_defines += '-DENABLE_FP16=1' - if get_option('enable-avx') - extra_defines += '-DUSE_AVX=1' - add_project_arguments(['-march=native'], language: ['c','cpp']) - message('-march=native added for AVX hardware acceleration.') - endif else warning ('Float16 for x86_64 enabled. However, software emulation is applied for fp16, making it slower and inconsistent. Use GCC 12+ for FP16 support. 
This build will probably fail unless you bring a compiler that supports fp16 for x64.') endif diff --git a/meson_options.txt b/meson_options.txt index de2578cb47..59accc1c1a 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -40,7 +40,7 @@ option('enable-fp16', type: 'boolean', value: false) option('enable-cublas', type: 'boolean', value: false) option('enable-openmp', type: 'boolean', value: true) option('enable-neon', type: 'boolean', value: false) -option('enable-avx', type: 'boolean', value: false) +option('enable-avx', type: 'boolean', value: true) option('enable-opencl', type: 'boolean', value: false) # ml-api dependency (to enable, install capi-inference from github.com/nnstreamer/api ) diff --git a/nntrainer/graph/graph_core.cpp b/nntrainer/graph/graph_core.cpp index b624e066e4..3eafbb9261 100644 --- a/nntrainer/graph/graph_core.cpp +++ b/nntrainer/graph/graph_core.cpp @@ -35,6 +35,10 @@ GraphCore::getSortedNode(unsigned int ith) const { return Sorted.at(ith); } +const unsigned int GraphCore::getSortedNodeIdx(const std::string &name) const { + return sorted_node_map.at(name); +} + void GraphCore::makeAdjacencyList( std::vector>> &adj) { /** initialize the adj list */ @@ -93,6 +97,11 @@ void GraphCore::topologicalSort() { if (Sorted.size() != node_list.size()) throw std::runtime_error("Internal error in topologicalSort"); + unsigned int idx = 0; + for (auto n : Sorted) { + sorted_node_map[n->getName()] = idx; + idx++; + } } const std::shared_ptr & diff --git a/nntrainer/graph/graph_core.h b/nntrainer/graph/graph_core.h index 83d3ce7c39..77aa63666a 100644 --- a/nntrainer/graph/graph_core.h +++ b/nntrainer/graph/graph_core.h @@ -91,6 +91,13 @@ class GraphCore { */ const std::shared_ptr &getSortedNode(unsigned int ith) const; + /** + * @brief getter of Sorted GraphNode index with name + * @param[in] layer name + * @ret index + */ + const unsigned int getSortedNodeIdx(const std::string &name) const; + /** * @brief getter of GraphNode with node name * @param[in] node name @@ -252,6 +259,7 @@ class GraphCore { std::vector> node_list; /**< Unordered Node List */ std::unordered_map node_map; /**< Unordered Node map */ + std::unordered_map sorted_node_map; /**< Unordered Node map */ std::vector> Sorted; /**< Ordered Node List */ bool sorted; /** if the node_list is sorted */ diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 2d4cfdc769..9ae982c203 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -337,7 +337,7 @@ void NetworkGraph::applyGradients( continue; } - if (rc.isGradientClipByGlobalNorm(i)) { + if (rc.isGradientClipByGlobalNorm(i) || rc.isMixedPrecision(i)) { /** * @note the weights whose gradient are to be clipped by global norm will * be clipped at once at the end of iteration and applied then. 
@@ -393,56 +393,118 @@ sharedConstTensors NetworkGraph::incremental_forwarding( return out; } -void NetworkGraph::backwarding( +bool NetworkGraph::backwarding( int iteration, - std::function, int)> &backwarding_op, - std::function &apply_grad_clip_op, - std::function stop_cb, void *userdata) const { + std::function, bool)> &forwarding_op, + std::function, int)> &backwarding_op, + std::function &lazy_apply_grad_op, + std::function stop_cb, void *userdata) { /** * last layer backwarding is run out of this loop */ auto iter_begin = getBackwardingBeginIter(); auto iter_end = getBackwardingEndIter(); + bool is_valid = true; /// there is no layer to train, so backwarding is essentially noop if (iter_begin == iter_end) { - return; + return true; } auto const &lptr_begin = (*iter_begin); + // graph_const_reverse_iterator + auto iter_ = iter_begin; if (lptr_begin->requireLabel() == false) throw std::runtime_error( "Error: last layer does not accept label, we can't train"); - for (auto iter = iter_begin; iter != iter_end && !stop_cb(userdata); iter++) { - auto &ln = *iter; + for (iter_ = iter_begin; iter_ != iter_end && !stop_cb(userdata); iter_++) { + auto &ln = *iter_; PROFILE_TIME_START(profile_keys.at(ln->getType())); - backwarding_op(ln, iteration); + is_valid = backwarding_op(ln, iteration); PROFILE_TIME_END(profile_keys.at(ln->getType())); + + if (!is_valid) { + std::cout << ln->getName() << " : Gradient has NaN --> " + << ln->getRunContext().getLossScale() << std::endl; + break; + } } - /** perform clipping of the gradients by global norm if any */ - if (clip_weights.empty()) - return; + if (!is_valid) { + /** if has NaN + * 1. reset the loss scale. : @todo Backoff_factor : default --> 0.5 + * 2. run forwarding from cur_iter to cend() && !stop_cb(userdata); + * 3. return false --> run backwarding again; + */ + float scale = (*iter_)->getRunContext().getLossScale(); + + NNTR_THROW_IF(scale == 1.0f, std::invalid_argument) + << "Loss Scale Factor is 1.0f"; + + float s = scale > 1.5f ? 
scale * 0.5f : 1.0f; + + resetLossScale(s); - /** calculate the global norm */ - Tensor global_norm_t( - TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()})); - float *global_norm_data = global_norm_t.getData(); - for (unsigned int idx = 0; idx < clip_weights.size(); idx++) { - auto const &w = clip_weights[idx]; - global_norm_data[idx] = w->getGradientNorm(); + auto f_iter = cbegin() + graph.getSortedNodeIdx((*iter_)->getName()); + + for (auto iter = f_iter; iter != cend() && !stop_cb(userdata); iter++) { + auto &ln = *iter; + ln->reStoreData(true); + } + + for (auto iter = f_iter; iter != cend() && !stop_cb(userdata); iter++) { + auto &ln = *iter; + PROFILE_TIME_START(profile_keys.at(ln->getType())); + forwarding_op(*iter, true); + PROFILE_TIME_END(profile_keys.at(ln->getType())); + } + + return false; } - float global_norm = global_norm_t.l2norm(); - /** apply the gradient with the above global norm */ - for (auto w : clip_weights) { - w->clipGradientByGlobalNorm(global_norm); + + /** perform clipping of the gradients by global norm if any */ + if (lazy_weights.empty()) + return true; + + if (is_clip_grad) { + /** calculate the global norm */ + Tensor global_norm_t( + TensorDim({1u, 1u, 1u, (unsigned int)lazy_weights.size()})); + float *global_norm_data = global_norm_t.getData(); + for (unsigned int idx = 0; idx < lazy_weights.size(); idx++) { + auto const &w = lazy_weights[idx]; + if (w->getGradientRef().getDataType() != TensorDim::DataType::FP32) { + Tensor grad_32 = w->getGradientRef().clone(TensorDim::DataType::FP32); + global_norm_data[idx] = grad_32.l2norm(); + } else { + global_norm_data[idx] = w->getGradientNorm(); + } + } + float global_norm = global_norm_t.l2norm(); + /** apply the gradient with the above global norm */ + for (auto w : lazy_weights) { + w->clipGradientByGlobalNorm(global_norm); + } } /** apply the gradient with the above global norm */ - for (auto w : clip_weights) { - apply_grad_clip_op(*w, iteration); + for (auto w : lazy_weights) { + lazy_apply_grad_op(*w, iteration); + } + nan_count++; + + /** @todo : handle as property : growth_interval : default --> 2000 */ + + if (nan_count > 2000) { + float scale = (*iter_)->getRunContext().getLossScale(); + /** @todo growth_factor : default --> 2.0 */ + float s = scale * 2.0f; + resetLossScale(s); + nan_count = 0; } + + return true; } LayerNode *NetworkGraph::computeBackwardEnd() { @@ -580,8 +642,15 @@ void NetworkGraph::addLayer(std::shared_ptr layer) { InPlace NetworkGraph::canExecuteInPlace(const std::shared_ptr &lnode) { - if (!lnode->supportInPlace()) + + if (!lnode->supportInPlace()) { return InPlace::NONE; + } + + if (lnode->getType() == InputLayer::type && + !istrequal(getTensorType()[2], "FP32")) { + return InPlace::NONE; + } /** layers which behave as a no-op - flatten */ auto no_op = [](const std::shared_ptr &lnode) { @@ -746,7 +815,7 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, [](const Var_Grad *vg) { return vg->getDim(); }); /** finalize the layer and get the final context */ - auto init_context = lnode->finalize(input_dims, getTensorType()); + auto init_context = lnode->finalize(input_dims, getTensorType(), exec_mode); /** * Request manager for either a pre-allocated output as input or a newly @@ -768,9 +837,10 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, * node is going to be used with in-place optimizations. 
*/ auto out_specs = init_context.getOutSpecs(); + /// @note try move inplace control to finalize bool shared_var = false, shared_grad = false; - if (lnode->executeInPlace() != InPlace::NONE) { + if (lnode->executeInPlace() != InPlace::NONE && lnode->supportInPlace()) { setInplaceSharedMemoryConfigByLayer(lnode, shared_var, shared_grad); for (unsigned int i = 0; i < out_specs.size(); ++i) { auto &s = out_specs.at(i); @@ -873,13 +943,17 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, } } + lnode->setDataType(init_context.getWeightDataType(), + init_context.getActivationDataType()); + lnode->configureRunContext( // TODO: update weights spec for trainable based on layer trainable prop tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(), lnode->getTrainable(), shared_weight_names), inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), - lnode->getTrainable(), shared_tensor_names)); + lnode->getTrainable(), shared_tensor_names), + init_context.getLossScale()); return outputs; } @@ -1027,7 +1101,8 @@ NetworkGraph::refinalizeContext(const std::shared_ptr &lnode, // TODO: update weights spec for trainable based on layer trainable prop weights, inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), - lnode->getTrainable(), shared_tensor_names)); + lnode->getTrainable(), shared_tensor_names), + init_context.getLossScale()); return outputs; } @@ -1197,7 +1272,7 @@ int NetworkGraph::initialize(ExecutionMode mode, */ if (tensor_manager->isLastAccess(rc.getWeightGrad(i).getName(), last_grad_access) || - (rc.isGradientClipByGlobalNorm(i) && + ((rc.isGradientClipByGlobalNorm(i) || rc.isMixedPrecision(i)) && tensor_manager->isSecondLastAccess(rc.getWeightGrad(i).getName(), last_grad_access))) { rc.getWeightObject(i).setAsGradientLastAccess(); @@ -1287,11 +1362,19 @@ int NetworkGraph::initialize(ExecutionMode mode, /** select weights which would require clipping of the gradients by global * norm if any */ - clip_weights = tensor_manager->getWeights([](const Weight *w) { + lazy_weights = tensor_manager->getWeights([](const Weight *w) { return w->hasGradient() && w->isGradientLastAccess() && - w->isGradientClipByGlobalNorm(); + (w->isGradientClipByGlobalNorm() || w->isMixedPrecision()); }); + is_clip_grad = false; + for (auto w : lazy_weights) { + if (w->isGradientClipByGlobalNorm()) { + is_clip_grad = true; + break; + } + } + return ML_ERROR_NONE; } @@ -1556,10 +1639,18 @@ void NetworkGraph::requestOptimizerVariable( const TensorDim &dim = w->getDim(); std::vector dims = cb(dim); w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables( - dims, w->getName(), TensorLifespan::MAX_LIFESPAN, - w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS)); + dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN, + w->isGradientClipByGlobalNorm(), w->isMixedPrecision(), + Tensor::Initializer::ZEROS)); } } } +void NetworkGraph::resetLossScale(float scale) { + for (auto iter = cbegin(); iter != cend(); iter++) { + auto &ln = *iter; + ln->getRunContext().setLossScale(scale); + } +} + } /* namespace nntrainer */ diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 5c9adf0363..867efef323 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -51,15 +51,17 @@ class NetworkGraph { optimize_memory(true), exec_mode(ExecutionMode::TRAIN), tensor_format("NCHW"), - tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {} + tensor_dtype(split("FP32-FP32", 
getRegex("\\-"))) { + nan_count = 0; + } /** * @brief Constructor of NeuralNetwork Graph Class * @param[in] enable_swap enable memory swap for tensor * @param[in] swap_path memory swap file path when the swap is enabled */ - NetworkGraph(bool enable_swap, const std::string &swap_path = "", - unsigned int lookahead = 0, + NetworkGraph(bool enable_swap, ExecutionMode mode = ExecutionMode::TRAIN, + const std::string &swap_path = "", unsigned int lookahead = 0, const std::string &tensor_format_ = "NCHW", const std::string &tensor_dtype_ = "FP32-FP32") : tensor_manager(std::make_shared(enable_swap, swap_path, lookahead, @@ -71,9 +73,11 @@ class NetworkGraph { backward_iter_end(nullptr), forward_iter_end(nullptr), optimize_memory(true), - exec_mode(ExecutionMode::TRAIN), + exec_mode(mode), tensor_format(tensor_format_), - tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {} + tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) { + nan_count = 0; + } /** * @brief Destructor of the NeuralNetwork Graph class @@ -206,13 +210,14 @@ class NetworkGraph { * @param[in] backwarding_op operation for the backwarding * @param[in] apply_grad_clip_op operation for applying the clip gradients */ - void backwarding( + bool backwarding( int iteration, - std::function, int)> &backwarding_op, - std::function &apply_grad_clip_op, + std::function, bool)> &forwarding_op, + std::function, int)> &backwarding_op, + std::function &lazy_apply_grad_op, std::function stop_cb = [](void *user_data) { return false; }, - void *user_data = nullptr) const; + void *user_data = nullptr); /** * @brief get begin iterator for the graph @@ -322,9 +327,9 @@ class NetworkGraph { * @param lnode layer node to finalize and set run context * @param prev_inputs previous input information */ - std::vector - finalizeContext(const std::shared_ptr &lnode, - const std::vector &prev_inputs); + std::vector finalizeContext( + const std::shared_ptr &lnode, + const std::vector &prev_inputs); /** * @brief Recreate run layer context from the given init layer context @@ -332,9 +337,9 @@ class NetworkGraph { * @param lnode layer node to finalize and set run context * @param prev_inputs previous input information */ - std::vector - refinalizeContext(const std::shared_ptr &lnode, - const std::vector &prev_inputs); + std::vector refinalizeContext( + const std::shared_ptr &lnode, + const std::vector &prev_inputs); /** Interface for manager */ @@ -444,6 +449,12 @@ class NetworkGraph { getLayerExecutionOrders(const std::shared_ptr &lnode); #endif // ENABLE_TEST + /** + * @brief reset the loss scale + * @param[in] scale + */ + void resetLossScale(float scale); + private: std::map sub_in_out; /** This is map to identify input and output layer name of subgraph */ @@ -480,7 +491,10 @@ class NetworkGraph { std::unordered_map profile_keys; /**< profile keys based on the layer type */ std::vector - clip_weights; /**< weights with global norm based clipping enabled */ + lazy_weights; /**< weights with global norm based clipping enabled */ + bool is_clip_grad; + + unsigned int nan_count; /** * @brief topological sort diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp index 1723ac677f..dbc9ce0c3b 100644 --- a/nntrainer/layers/bn_layer.cpp +++ b/nntrainer/layers/bn_layer.cpp @@ -38,6 +38,8 @@ enum BNParams { var, gamma, beta, + mu_b, + var_b, deviation, invstd, cvar, @@ -73,6 +75,10 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { TensorDim dim(context.getFormat(), context.getWeightDataType()); + if (context.getExecutionMode() 
== ml::train::ExecutionMode::TRAIN) { + dim.setDataType(TensorDim::DataType::FP32); + } + /// @note this logic cannot tell channel is actually 1 or it is just not used. auto &axis_prop = std::get(bn_props); unsigned int axis; @@ -99,26 +105,40 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { } wt_idx[BNParams::mu] = - context.requestWeight(dim, bnparams_mu, WeightRegularizer::NONE, 1.0f, 0.0f, - "moving_mean", false); + context.requestWeight(dim, dim, bnparams_mu, WeightRegularizer::NONE, 1.0f, + 0.0f, "moving_mean", false); wt_idx[BNParams::var] = - context.requestWeight(dim, bnparams_var, WeightRegularizer::NONE, 1.0f, + context.requestWeight(dim, dim, bnparams_var, WeightRegularizer::NONE, 1.0f, 0.0f, "moving_variance", false); wt_idx[BNParams::gamma] = - context.requestWeight(dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f, - weight_decay, "gamma", true); + context.requestWeight(dim, dim, bnparams_gamma, WeightRegularizer::NONE, + 1.0f, weight_decay, "gamma", true); wt_idx[BNParams::beta] = - context.requestWeight(dim, bnparams_beta, WeightRegularizer::NONE, 1.0f, - bias_decay, "beta", true); + context.requestWeight(dim, dim, bnparams_beta, WeightRegularizer::NONE, + 1.0f, bias_decay, "beta", true); + + wt_idx[BNParams::mu_b] = + context.requestTensor(dim, "moviing_mean_backup", Tensor::Initializer::NONE, + false, TensorLifespan::ITERATION_LIFESPAN); + + wt_idx[BNParams::var_b] = context.requestTensor( + dim, "moviing_variance_backup", Tensor::Initializer::NONE, false, + TensorLifespan::ITERATION_LIFESPAN); /** * caches the deviation -> input - avg(input) * @todo check if avoiding this storage and adding dependency on input (no * more in-place calculation) can save memory during memory optimization. */ + TensorDim in_dim_ = in_dim; + + if (context.getExecutionMode() == ml::train::ExecutionMode::TRAIN) { + in_dim_.setDataType(TensorDim::DataType::FP32); + } + wt_idx[BNParams::deviation] = - context.requestTensor(in_dim, "deviation", Tensor::Initializer::NONE, false, - TensorLifespan::ITERATION_LIFESPAN); + context.requestTensor(in_dim_, "deviation", Tensor::Initializer::NONE, + false, TensorLifespan::ITERATION_LIFESPAN); /** caches the inverse standard deviation */ wt_idx[BNParams::invstd] = context.requestTensor(dim, "invstd", Tensor::Initializer::NONE, false, @@ -130,7 +150,7 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { * as the output of this layer need not be stored all the time. */ wt_idx[BNParams::t_full] = - context.requestTensor(in_dim, "tensor_full", Tensor::Initializer::NONE, + context.requestTensor(in_dim_, "tensor_full", Tensor::Initializer::NONE, false, TensorLifespan::CALC_DERIV_LIFESPAN); /** * caches variance + epsilon as well. 
@@ -164,8 +184,32 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context, Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]); Tensor &beta = context.getWeight(wt_idx[BNParams::beta]); - Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); + Tensor em_input, em_hidden; + + Tensor &input_ = em_input; + Tensor &hidden_ = em_hidden; + + if (training) { + if (context.getInput(SINGLE_INOUT_IDX).getDataType() != + TensorDim::DataType::FP32) { + input_ = + context.getInput(SINGLE_INOUT_IDX).clone(TensorDim::DataType::FP32); + } else { + input_ = context.getInput(SINGLE_INOUT_IDX); + } + + if (context.getOutput(SINGLE_INOUT_IDX).getDataType() != + TensorDim::DataType::FP32) { + hidden_ = + context.getOutput(SINGLE_INOUT_IDX).clone(TensorDim::DataType::FP32); + } else { + hidden_ = context.getOutput(SINGLE_INOUT_IDX); + } + } else { + input_ = context.getInput(SINGLE_INOUT_IDX); + hidden_ = context.getOutput(SINGLE_INOUT_IDX); + } + Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]); Tensor &invstd = context.getTensor(wt_idx[BNParams::invstd]); @@ -176,6 +220,22 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context, Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]); if (training) { + + Tensor &mu_b = context.getTensor(wt_idx[BNParams::mu_b]); + Tensor &var_b = context.getTensor(wt_idx[BNParams::var_b]); + + if (context.reStoreData()) { + mu.copyData(mu_b); + var.copyData(var_b); + deviation.setZero(); + invstd.setZero(); + t_reduced.setZero(); + cvar.setZero(); + } else { + mu_b.copyData(mu); + var_b.copyData(var); + } + input_.average(axes_to_reduce, t_reduced); input_.subtract(t_reduced, deviation); @@ -200,13 +260,38 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context, deviation.multiply(invstd, hidden_); hidden_.multiply_i(gamma); hidden_.add_i(beta); + + if (training && hidden_.getDataType() != + context.getOutput(SINGLE_INOUT_IDX).getDataType()) + context.getOutput(SINGLE_INOUT_IDX).copyData(hidden_); } void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) { Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]); - const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX); - Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + + Tensor em_dx, deriv32; + bool deriv_copyed = false; + + const Tensor deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX); + + if (deriv.getDataType() != TensorDim::DataType::FP32) { + deriv_copyed = true; + TensorDim dim = deriv.getDim(); + dim.setDataType(TensorDim::DataType::FP32); + deriv32 = Tensor(dim, true); + deriv32.copyData(deriv); + } + + Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() == + TensorDim::DataType::FP32 + ? context.getOutgoingDerivative(SINGLE_INOUT_IDX) + : em_dx; + + if (dx.empty()) + dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX) + .clone(TensorDim::DataType::FP32); + Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]); Tensor &invstd = context.getTensor(wt_idx[BNParams::invstd]); Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]); @@ -214,7 +299,9 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) { Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]); Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]); - deviation.multiply(deriv, t_full); + t_full.setZero(); + + deviation.multiply((deriv_copyed ? 
deriv32 : deriv), t_full); t_full.average(axes_to_reduce, t_reduced); t_reduced.divide_i(cvar); deviation.multiply_i(t_reduced); @@ -233,22 +320,37 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) { Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]); dbeta.divide(divider, t_reduced); } else { - deriv.average(axes_to_reduce, t_reduced); + (deriv_copyed ? deriv32 : deriv).average(axes_to_reduce, t_reduced); } - deriv.subtract(t_reduced, dx); + (deriv_copyed ? deriv32 : deriv).subtract(t_reduced, dx); dx.subtract_i(deviation); invstd.multiply_i(gamma); dx.multiply_i(invstd); + + if (dx.getDataType() != + context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType()) + context.getOutgoingDerivative(SINGLE_INOUT_IDX).copyData(dx); } void BatchNormalizationLayer::calcGradient(RunLayerContext &context) { /** dgamma is calculated in calcDerivative. dbeta is calculated here */ Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]); - const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX); - deriv.sum(axes_to_reduce, dbeta); + Tensor deriv32; + bool deriv_copyed = false; + + const Tensor deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX); + if (deriv.getDataType() != TensorDim::DataType::FP32) { + deriv_copyed = true; + TensorDim dim = deriv.getDim(); + dim.setDataType(TensorDim::DataType::FP32); + deriv32 = Tensor(dim, true); + deriv32.copyData(deriv); + } + + (deriv_copyed ? deriv32 : deriv).sum(axes_to_reduce, dbeta); } void BatchNormalizationLayer::exportTo( diff --git a/nntrainer/layers/bn_layer.h b/nntrainer/layers/bn_layer.h index 5bce6f471f..f8d611cd9d 100644 --- a/nntrainer/layers/bn_layer.h +++ b/nntrainer/layers/bn_layer.h @@ -124,7 +124,8 @@ class BatchNormalizationLayer : public Layer { float divider; /**< size of the axes of the reduced */ std::vector axes_to_reduce; /**< target axes to reduce */ - std::array wt_idx; /**< indices of the weights and tensors */ + std::array + wt_idx; /**< indices of the weights and tensors */ std::tuple diff --git a/nntrainer/layers/conv2d_layer.cpp b/nntrainer/layers/conv2d_layer.cpp index c059ae9caf..9b52b6913d 100644 --- a/nntrainer/layers/conv2d_layer.cpp +++ b/nntrainer/layers/conv2d_layer.cpp @@ -38,7 +38,8 @@ namespace { static TensorDim calcCol2ImOutputDim(const TensorDim &out, const TensorDim &kdim) { - return TensorDim({kdim.getFeatureLen(), out.width() * out.height()}); + return TensorDim({kdim.getFeatureLen(), out.width() * out.height()}, + out.getTensorType()); } /** @@ -56,7 +57,11 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim, const std::array &mstride, const std::array &dilation, Tensor &image) { - auto [pt, pb, pl, pr] = padding; + + auto pt = padding[0]; + auto pb = padding[1]; + auto pl = padding[2]; + auto pr = padding[3]; unsigned k_height = kdim.height(); unsigned k_width = kdim.width(); @@ -84,32 +89,52 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim, int h_stride_end = im_eff_height - eff_k_height - pt; int w_stride_end = im_eff_width - eff_k_width - pl; - unsigned col_w = 0; - for (int hs = -pt; hs <= h_stride_end; hs += hstride) { - for (int ws = -pl; ws <= w_stride_end; ws += wstride) { - unsigned col_h = 0; - int patch_height_end = hs + eff_k_height; - int patch_width_end = ws + eff_k_width; - for (unsigned c = 0; c < im_channel; c++) { - for (int h = hs; h < patch_height_end; h += hdilation) { - if (h < 0 || im_height <= h) { - col_h += k_width; - continue; - } - for (int w = ws; w < patch_width_end; w += wdilation) { 
- if (w < 0 || im_width <= w) { - col_h++; + /** @todo We need to implement way to use this kind of function to work inside + * of Tensor. Then we could remove to access the getData or getValue which has + * dependecy of data type. + */ + auto apply_data = [&](T *val) { + unsigned col_w = 0; + for (int hs = -pt; hs <= h_stride_end; hs += hstride) { + for (int ws = -pl; ws <= w_stride_end; ws += wstride) { + unsigned col_h = 0; + int patch_height_end = hs + eff_k_height; + int patch_width_end = ws + eff_k_width; + for (unsigned c = 0; c < im_channel; c++) { + for (int h = hs; h < patch_height_end; h += hdilation) { + if (h < 0 || im_height <= h) { + col_h += k_width; continue; } - - float *val = image.getAddress(0, c, h, w); - *val += col_matrix.getValue(0, 0, col_h, col_w); - col_h++; + for (int w = ws; w < patch_width_end; w += wdilation) { + if (w < 0 || im_width <= w) { + col_h++; + continue; + } + + val = image.getAddress(0, c, h, w); + *val += col_matrix.getValue(0, 0, col_h, col_w); + col_h++; + } } } + col_w++; } - col_w++; } + }; + + if (image.getDataType() == nntrainer::Tdatatype::FP32) { + float val; + apply_data(&val); + } +#ifdef ENABLE_FP16 + else if (image.getDataType() == nntrainer::Tdatatype::FP16) { + _FP16 val; + apply_data(&val); + } +#endif + else { + throw std::runtime_error("Not supported datatype"); } } @@ -179,7 +204,10 @@ static void im2col(const Tensor &in, const TensorDim &kdim, // } */ - auto [pt, pb, pl, pr] = padding; + auto pt = padding[0]; + auto pb = padding[1]; + auto pl = padding[2]; + auto pr = padding[3]; unsigned int channel = in.channel(); int in_height = in.height(); @@ -198,49 +226,65 @@ static void im2col(const Tensor &in, const TensorDim &kdim, unsigned int out_width = (width - eff_k_width) / mstride[1] + 1; out.reshape( - TensorDim({out_height * out_width, in.channel() * k_height * k_width})); - float *out_data = out.getData(); - - int h_stride_end = height - eff_k_height - pt; - int w_stride_end = width - eff_k_width - pl; - - /// get a patch, size of kernel - /// hs is height_strided, ws is width_strided - unsigned int owidth = out.width(); - unsigned int base_im_w = 0; - for (int hs = -pt; hs <= h_stride_end; hs += mstride[0]) { - unsigned int base_im_h = 0; - int patch_height_end = eff_k_height + hs; - /// map the patch to a single line looping through channel - for (unsigned int c = 0; c < channel; ++c) { - for (int h = hs; h < patch_height_end; h += dilation[0]) { - if (h < 0 || in_height <= h) { - base_im_h += k_width; - continue; - } - - unsigned int im_w = base_im_w; - for (int ws = -pl; ws <= w_stride_end; ws += mstride[1]) { - unsigned int im_h = base_im_h; - int patch_width_end = eff_k_width + ws; + TensorDim({out_height * out_width, in.channel() * k_height * k_width}, + in.getTensorType())); + // float *out_data = out.getData(); + + auto apply_data = [&](T *out_data) { + int h_stride_end = height - eff_k_height - pt; + int w_stride_end = width - eff_k_width - pl; + + /// get a patch, size of kernel + /// hs is height_strided, ws is width_strided + unsigned int owidth = out.width(); + unsigned int base_im_w = 0; + for (int hs = -pt; hs <= h_stride_end; hs += mstride[0]) { + unsigned int base_im_h = 0; + int patch_height_end = eff_k_height + hs; + /// map the patch to a single line looping through channel + for (unsigned int c = 0; c < channel; ++c) { + for (int h = hs; h < patch_height_end; h += dilation[0]) { + if (h < 0 || in_height <= h) { + base_im_h += k_width; + continue; + } - for (int w = ws; w < patch_width_end; w += 
dilation[1]) { - if (w < 0 || in_width <= w) { + unsigned int im_w = base_im_w; + for (int ws = -pl; ws <= w_stride_end; ws += mstride[1]) { + unsigned int im_h = base_im_h; + int patch_width_end = eff_k_width + ws; + + for (int w = ws; w < patch_width_end; w += dilation[1]) { + if (w < 0 || in_width <= w) { + im_h++; + continue; + } + out_data[im_w * owidth + im_h] = in.getValue(0, c, h, w); im_h++; - continue; } - out_data[im_w * owidth + im_h] = in.getValue(0, c, h, w); - im_h++; + im_w++; } - im_w++; + base_im_h += k_width; } - base_im_h += k_width; } + base_im_w += out_width; } - base_im_w += out_width; + }; + + if (out.getDataType() == nntrainer::Tdatatype::FP32) { + float *out_data = out.getData(); + apply_data(out_data); + } +#ifdef ENABLE_FP16 + else if (out.getDataType() == nntrainer::Tdatatype::FP16) { + _FP16 *out_data = out.getData<_FP16>(); + apply_data(out_data); + } +#endif + else { + throw std::runtime_error("Not supported datatype"); } } - } // namespace enum ConvParams { weight, bias }; @@ -279,9 +323,13 @@ void Conv2DLayer::finalize(InitLayerContext &context) { auto &dilation = std::get>(conv_props); - TensorDim kernel_dim = - TensorDim(filter_size, in_dim.channel(), kernel_size[0], kernel_size[1]); - TensorDim bias_dim = TensorDim(1, filter_size, 1, 1); + auto in_t_type = in_dim.getTensorType(); + in_t_type.data_type = context.getWeightDataType(); + + TensorDim kernel_dim = TensorDim(filter_size, in_dim.channel(), + kernel_size[0], kernel_size[1], in_t_type); + + TensorDim bias_dim = TensorDim(1, filter_size, 1, 1, in_t_type); padding = std::get(conv_props) .compute(in_dim, kernel_dim, {stride[0], stride[1]}, @@ -309,6 +357,9 @@ void Conv2DLayer::finalize(InitLayerContext &context) { out_dim.channel(filter_size); out_dim.height((eff_in_height - eff_k_height) / stride[0] + 1); out_dim.width((eff_in_width - eff_k_width) / stride[1] + 1); + + out_dim.setTensorType(in_dim.getTensorType()); + context.setOutputDimensions({out_dim}); NNTR_THROW_IF(eff_in_height < kernel_size[0] || eff_in_width < kernel_size[1], @@ -379,6 +430,8 @@ void Conv2DLayer::forwarding(RunLayerContext &context, bool training) { TensorDim filter_dim_squeezed{filter_kernel.batch(), filter_kernel.getDim().getFeatureLen()}; + filter_dim_squeezed.setTensorType(filter_kernel.getTensorType()); + filter_kernel.reshape(filter_dim_squeezed); /** diff --git a/nntrainer/layers/dropout.cpp b/nntrainer/layers/dropout.cpp index c00c31d10b..2b146e6e4d 100644 --- a/nntrainer/layers/dropout.cpp +++ b/nntrainer/layers/dropout.cpp @@ -48,7 +48,10 @@ void DropOutLayer::forwarding(RunLayerContext &context, bool training) { /** @todo make this in-place */ if (training && rate_ > epsilon) { Tensor &mask_ = context.getTensor(mask_idx[i]); - mask_.dropout_mask(rate_); + if (!context.reStoreData()) { + mask_.dropout_mask(rate_); + } + input_.multiply(mask_, output_); } else { output_.fill(input_); diff --git a/nntrainer/layers/input_layer.cpp b/nntrainer/layers/input_layer.cpp index eabd40b297..a67701da2c 100644 --- a/nntrainer/layers/input_layer.cpp +++ b/nntrainer/layers/input_layer.cpp @@ -34,7 +34,8 @@ static constexpr size_t SINGLE_INOUT_IDX = 0; InputLayer::InputLayer() : Layer(), - input_props(props::Normalization(), props::Standardization()) {} + input_props(props::Normalization(), props::Standardization()), + is_inplace(true) {} void InputLayer::setProperty(const std::vector &values) { auto remain_props = loadProperties(values, input_props); @@ -47,7 +48,7 @@ void InputLayer::forwarding(RunLayerContext &context, bool 
training) { Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); if (!context.executeInPlace()) { Tensor &input_ = context.getInput(SINGLE_INOUT_IDX); - hidden_.copy(input_); + hidden_.copyData(input_); } if (std::get(input_props)) @@ -70,7 +71,22 @@ void InputLayer::finalize(InitLayerContext &context) { std::vector output_dims = context.getInputDimensions(); + for (auto &d : output_dims) { + d.setDataType(context.getActivationDataType()); + } + context.setOutputDimensions(output_dims); + + is_inplace = true; + + /** + * @note Input Layer assumes that the input tensor is always FP32. Therefore, if the + * activation data type is not fp32, then it does not support in-place + * operation. + */ + if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32) { + is_inplace = false; + } } } /* namespace nntrainer */ diff --git a/nntrainer/layers/input_layer.h b/nntrainer/layers/input_layer.h index f6728d676b..e9183e23d1 100644 --- a/nntrainer/layers/input_layer.h +++ b/nntrainer/layers/input_layer.h @@ -82,7 +82,7 @@ class InputLayer : public Layer { /** * @copydoc Layer::supportInPlace() */ - bool supportInPlace() const override { return true; } + bool supportInPlace() const override { return is_inplace; } /** * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods @@ -105,6 +105,7 @@ class InputLayer : public Layer { private: std::tuple input_props; + bool is_inplace; }; } // namespace nntrainer diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index 1a66aed3cd..ce3b5c3fd3 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -38,13 +38,11 @@ static void suffixSpec(VarGradSpecV2 &spec, unsigned int idx) { } } -InitLayerContext::InitLayerContext(const std::vector &dim, - const std::vector &req_out_connected, - bool in_place_, const std::string &n, - const std::string &prefix_, - const float max_norm, - std::array tensor_type_, - const float loss_scale_) : +InitLayerContext::InitLayerContext( + const std::vector &dim, const std::vector &req_out_connected, + bool in_place_, const std::string &n, const std::string &prefix_, + const float max_norm, std::array tensor_type_, + const float loss_scale_, ml::train::ExecutionMode mode_) : input_dim(dim), in_place(in_place_), clip_by_global_norm(max_norm), @@ -53,7 +51,8 @@ InitLayerContext::InitLayerContext(const std::vector &dim, name(n), prefix(prefix_), tensor_type(tensor_type_), - loss_scale(loss_scale_) { + loss_scale(loss_scale_), + mode(mode_) { NNTR_THROW_IF(!validate(), std::invalid_argument) << "Invalid init context name: " << name << " num inputs: " << getNumInputs(); @@ -126,13 +125,15 @@ const std::vector &InitLayerContext::getOutSpecs() const { } RunLayerContext::RunLayerContext(const std::string &name, bool trainable, - float l, bool in_place_, + float l, bool in_place_, float loss_scale_, const std::vector &w, const std::vector &in, const std::vector &out, const std::vector &t) : loss(l), in_place(in_place_), + loss_scale(loss_scale_), + restoreData(false), weights(w), inputs(in), outputs(out), @@ -169,6 +170,16 @@ Tensor &RunLayerContext::getWeightGrad(unsigned int idx) const { return weights[idx]->getGradientRef(); } +/** + * @brief Get the FP32 Weight tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the FP32 weight tensor + */ +Tensor &RunLayerContext::getWeightFP32(unsigned int idx) const { + return weights[idx]->getVariableFP32Ref(); +} + /** * @brief Get the Weight Optimizer Variable 
tensor object * * @param idx Identifier of the weight @@ -402,6 +413,17 @@ bool RunLayerContext::isGradientClipByGlobalNorm(unsigned int idx) const { return weights[idx]->isGradientClipByGlobalNorm(); } +bool RunLayerContext::isMixedPrecision(unsigned int idx) const { + return weights[idx]->isMixedPrecision(); +} + +bool RunLayerContext::isMixedPrecision() const { + for (auto w : weights) + if (w->isMixedPrecision()) + return true; + return false; +} + /** * @brief Get the tensor name * diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h index 43e9d8eaf8..0bec32c608 100644 --- a/nntrainer/layers/layer_context.h +++ b/nntrainer/layers/layer_context.h @@ -57,13 +57,14 @@ class InitLayerContext { * @param prefix_ prefix * @param max_norm max norm */ - InitLayerContext(const std::vector &dim, - const std::vector &req_out_connected, bool in_place_, - const std::string &n = "", const std::string &prefix_ = "", - const float max_norm = 0.0, - std::array tensor_type_ = {"NCHW", "FP32", - "FP32"}, - const float loss_scale = 0.0); + InitLayerContext( + const std::vector &dim, + const std::vector &req_out_connected, bool in_place_, + const std::string &n = "", const std::string &prefix_ = "", + const float max_norm = 0.0, + std::array tensor_type_ = {"NCHW", "FP32", "FP32"}, + const float loss_scale = 1.0, + ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN); /** * @brief get Tensor Format of Layer * * @@ -101,6 +102,14 @@ class InitLayerContext { */ const std::string &getName() const { return name; } + /** + * @brief get Execution Mode + * + * @return Mode Execution Mode : ml::train::ExecutionMode::INFERENCE | + * ml::train::ExecutionMode::TRAIN + */ + const ml::train::ExecutionMode &getExecutionMode() const { return mode; } + /** + * @brief Get the number of inputs for the layer * @@ -201,6 +210,35 @@ class InitLayerContext { return weights_spec.size() - 1; } + /** + * @brief Request a new weight for the layer + * + * @param dim dimension of Variable of the weight + * @param dim_g dimension of Gradient of the weight + * @param init initializer for the weight + * @param reg regularizer for the weight + * @param reg_const regularization constant for the weight + * @param name name of the weight + * @param trainable if the weight is trainable (require gradient or not) + * @return unsigned int index of the weight for its getter + * + * @todo Consider providing a guarantee that the returned indices will always + * start from 0 and will always be incremental. + */ + unsigned int requestWeight(const TensorDim &dim, const TensorDim &dim_g, + const Tensor::Initializer init, + const WeightRegularizer reg, const float reg_const, + const float decay, const std::string &name, + bool trainable = true, unsigned int out_axis = 3) { + + /** @note : We assume the gradient type is the same as the Activation data + * type. */ + weights_spec.emplace_back(dim, dim_g, init, reg, reg_const, decay, + clip_by_global_norm, trainable, + prefix + ":" + name, out_axis, loss_scale); + return weights_spec.size() - 1; + } + /** + * @brief Request a new weight for the layer + * + * @@ -348,6 +386,14 @@ class InitLayerContext { */ bool executeInPlace() const { return in_place; } + /** + * @brief get Initial value of Loss_Scale. 
This is set to RunLayerContext + * and updated + * + * @return loss_scale + */ + float getLossScale() const { return loss_scale; } + private: std::vector input_dim; /**< Input dimensions for the layer */ bool in_place; /**< if the layer is expected to run in-place */ @@ -365,6 +411,7 @@ class InitLayerContext { std::string prefix; /**< prefix of the layer */ std::array tensor_type; float loss_scale; /**< loss_scale value */ + ml::train::ExecutionMode mode; }; /** @@ -385,7 +432,8 @@ class RunLayerContext { * @brief Construct a new Run Layer Context object * */ - RunLayerContext() : loss(0.0), in_place(false) {} + RunLayerContext() : + loss(0.0), in_place(false), loss_scale(1.0), restoreData(false) {} /** * @brief Construct a new Run Layer Context object * @@ -396,6 +444,17 @@ class RunLayerContext { std::get(props).set(name); } + /** + * @brief Construct a new Run Layer Context object + * + */ + RunLayerContext(const std::string &name, bool in_place_, float loss_scale_) : + RunLayerContext() { + in_place = in_place_; + std::get(props).set(name); + loss_scale = loss_scale_; + } + /** * @brief Construct a new Run Layer Context object * * @@ -403,13 +462,15 @@ class RunLayerContext { * @param trainable if the layer is trainable * @param l loss of the layer * @param in_place_ execution in-place of the layer + * @param loss_scale loss_scale of the layer * @param w weights of the layer * @param in inputs of the layer * @param out outputs of the layer * @param t extra tensors of the layer */ RunLayerContext(const std::string &name, bool trainable, float l, - bool in_place_, const std::vector &w, + bool in_place_, float loss_scale_, + const std::vector &w, const std::vector &in, const std::vector &out, const std::vector &t); @@ -463,6 +524,15 @@ class RunLayerContext { Tensor &getWeightGrad(unsigned int idx) const; /** + * @brief Get the FP32 Weight tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the FP32 weight tensor + */ + Tensor &getWeightFP32(unsigned int idx) const; + + /** + * @brief Get the Weight Optimizer Variable tensor object * * @param idx Identifier of the weight @@ -659,6 +729,20 @@ class RunLayerContext { */ bool isGradientClipByGlobalNorm(unsigned int idx) const; + /** + * @brief check if the weight is mixed precision + * + * @param idx index + * @return bool true if it is mixed precision + */ + bool isMixedPrecision(unsigned int idx) const; + + /** + * @brief check if the weight is mixed precision + * @return bool true if it is mixed precision + */ + bool isMixedPrecision() const; + /** * @brief Get the tensor name * @@ -878,10 +962,42 @@ class RunLayerContext { */ ml::train::LayerComputeEngine getComputeEngine() { return compute_engine; } + /** + * @brief get loss scale + * @return loss scale + */ + float getLossScale() { return loss_scale; } + + /** + * @brief set Loss_Scale. + * + * @param scale loss_scale to set + */ + void setLossScale(float scale) { + loss_scale = scale; + for (auto w : weights) { + w->setLossScale(scale); + } + } + + /** + * @brief set the restore data flag used for mixed precision. + * + */ + void reStoreData(bool nb) { restoreData = nb; } + + /** + * @brief get the restore data flag used for mixed precision. 
+ * + */ + bool reStoreData() { return restoreData; } + private: std::tuple props; /**< props of the layer */ float loss; /**< loss of the layer */ - bool in_place; /**< if the layer is expected to run in-place */ + bool in_place; /**< if the layer is expected to run in-place */ + float loss_scale; /**< loss_scale of the layer */ + bool restoreData; /**< reset output for mixed precsion */ std::vector weights; /**< weights of the layer */ std::vector inputs; /**< inputs of the layer */ diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index 36563b6570..94a41cac98 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -188,6 +189,7 @@ LayerNode::LayerNode(std::unique_ptr &&l) : inplace(InPlace::NONE), needs_calc_derivative(false), needs_calc_gradient(false), + output_connections(), run_context(nullptr), layer_node_props( @@ -198,7 +200,8 @@ LayerNode::LayerNode(std::unique_ptr &&l) : new RealizationPropsType(props::Flatten(), props::Activation())), loss(new props::Loss()), regularization_loss(0.0f), - exec_order({0, 0, 0, 0}) { + exec_order({0, 0, 0, 0}), + needs_restore_data(false) { if (layer && layer->getType() == TimeDistLayer::type) { std::get(*layer_node_props).set(true); } @@ -468,9 +471,11 @@ void LayerNode::exportTo(Exporter &exporter, layer->exportTo(exporter, method); } -void LayerNode::read(std::ifstream &file, bool opt_var) { +void LayerNode::read(std::ifstream &file, bool opt_var, + ml::train::ExecutionMode mode) { NNTR_THROW_IF(!run_context, std::runtime_error) << __func__ << " layer needs to be finalized first!"; + if (opt_var) { for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { if (run_context->isGradientLastAccess(i) && getTrainable()) { @@ -481,16 +486,41 @@ void LayerNode::read(std::ifstream &file, bool opt_var) { } } } else { + for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { /// @note shared weights are only be read at the first acecss if (run_context->isGradientLastAccess(i)) { - run_context->getWeight(i).read(file); + if (layer->getType() == BatchNormalizationLayer::type) { + if ((mode == ml::train::ExecutionMode::TRAIN) && + (this->getWeightDataType() != TensorDim::DataType::FP32)) { + + /** @note for batch normalization layer, we do need full precision + * for training. but weight can be saved with other type. 
for + * training, bn weight type is fixed with full precision */ + + TensorDim dim = run_context->getWeight(i).getDim(); + dim.setDataType(this->getWeightDataType()); + Tensor T_read(dim, true); + T_read.read(file); + run_context->getWeight(i).copyData(T_read); + } else { + run_context->getWeight(i).read(file); + } + } else { + run_context->getWeight(i).read(file); + } + + if (run_context->isMixedPrecision(i) && getTrainable() && + !run_context->getWeightFP32(i).empty()) { + run_context->getWeightFP32(i).copyData(run_context->getWeight(i)); + } } } } } -void LayerNode::save(std::ofstream &file, bool opt_var) const { +void LayerNode::save(std::ofstream &file, bool opt_var, + ml::train::ExecutionMode mode) const { NNTR_THROW_IF(!run_context, std::runtime_error) << __func__ << " layer needs to be finalized first!"; @@ -510,7 +540,29 @@ void LayerNode::save(std::ofstream &file, bool opt_var) const { // @note shared weights are only saved at the first access for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { if (run_context->isGradientLastAccess(i)) { - run_context->getWeight(i).save(file); + + /** @note For the batch normalization layer, we need full precision for + * training and the data type of the weight is full precision. But for + * inference, we have to save them in the activation data type. */ + + if (layer->getType() == BatchNormalizationLayer::type) { + if ((mode == ml::train::ExecutionMode::TRAIN) && + (this->getWeightDataType() != TensorDim::DataType::FP32)) { + TensorDim dim = run_context->getWeight(i).getDim(); + + dim.setDataType(this->getWeightDataType()); + + Tensor T_save(dim, true); + + T_save.copyData(run_context->getWeight(i)); + + T_save.save(file); + } else { + run_context->getWeight(i).save(file); + } + } else { + run_context->getWeight(i).save(file); + } } } } @@ -533,7 +585,8 @@ void LayerNode::clearOptVar() { * @brief Finalize creating the layer node */ InitLayerContext LayerNode::finalize(const std::vector &input_dims, - std::array tensor_type) { + std::array tensor_type, + ml::train::ExecutionMode mode) { // auto get_tensor_datatype = [](const std::string ty) -> TensorDim::DataType // { return from_string(ty); // }; @@ -609,7 +662,7 @@ InitLayerContext LayerNode::finalize(const std::vector &input_dims, const auto &scope = getSharedFrom().empty() ? 
getName() : getSharedFrom(); float max_norm = 0.0; - float loss_scale = 0.0; + float loss_scale = 1.0; if (!std::get(*layer_node_props).empty()) max_norm = std::get(*layer_node_props).get(); @@ -635,9 +688,9 @@ InitLayerContext LayerNode::finalize(const std::vector &input_dims, out_info.push_back(true); } - auto context = InitLayerContext(actual_input_dims, out_info, - executeInPlace() != InPlace::NONE, getName(), - scope, max_norm, tensor_type, loss_scale); + auto context = InitLayerContext( + actual_input_dims, out_info, executeInPlace() != InPlace::NONE, getName(), + scope, max_norm, tensor_type, loss_scale, mode); layer->finalize(context); @@ -758,8 +811,23 @@ LayerNode::refinalize(const std::vector &input_dims) { */ void LayerNode::forwarding(bool training) { loss->set(run_context->getRegularizationLoss()); + PROFILE_TIME_START(forward_event_key); + if (reStoreData()) { + if (executeInPlace() == InPlace::NONE) { + for (unsigned int i = 0; i < run_context->getNumOutputs(); ++i) { + run_context->getOutput(i).setValue(0); + } + for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) { + if (run_context->weightHasGradient(i)) { + run_context->getWeightGrad(i).setValue(0); + } + } + } + } + layer->forwarding(*run_context, training); + reStoreData(false); PROFILE_TIME_END(forward_event_key); TRACE_MEMORY() << getName() + ": F"; TRACE_TIME() << getName() + ": F"; @@ -874,10 +942,11 @@ float LayerNode::getLoss() const { return *loss; } void LayerNode::configureRunContext(const std::vector &weights, const std::vector &inputs, const std::vector &outputs, - const std::vector &tensors) { + const std::vector &tensors, + float loss_scale) { run_context = std::make_unique( - getName(), getTrainable(), 0.0f, executeInPlace() != InPlace::NONE, weights, - inputs, outputs, tensors); + getName(), getTrainable(), 0.0f, executeInPlace() != InPlace::NONE, + loss_scale, weights, inputs, outputs, tensors); } /** diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index f373386605..5e4e5dfa96 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -131,6 +131,32 @@ class LayerNode final : public ml::train::Layer, public GraphNode { setProperty({"name=" + name}); } + /** + * @brief set weight and activation data type of layer + * + * @param[in] weight data type, activation data type + */ + void setDataType(const TensorDim::DataType w_type, + const TensorDim::DataType a_type) { + data_type = {w_type, a_type}; + } + + /** + * @brief Get the Weight Data Type + * + * @return TensorDim::DataType weight data type + */ + const TensorDim::DataType getWeightDataType() const { return data_type[0]; } + + /** + * @brief Get the Activation Data Type + * + * @return TensorDim::DataType activation data type + */ + const TensorDim::DataType getActivationDataType() const { + return data_type[1]; + } + /** * @brief Get the Input Connection Index object * @@ -253,9 +279,10 @@ class LayerNode final : public ml::train::Layer, public GraphNode { * will be made available during execution of the layer with the context. * @note configureRunContext() is expected to called right after this. 
*/ - InitLayerContext finalize(const std::vector &input_dims = {}, - std::array tensor_type = { - "NCHW", "FP32", "FP32"}); + InitLayerContext + finalize(const std::vector &input_dims = {}, + std::array tensor_type = {"NCHW", "FP32", "FP32"}, + ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN); /** * @brief Refinalize creating the layer node @@ -487,6 +514,7 @@ class LayerNode final : public ml::train::Layer, public GraphNode { const std::vector getOutputDimensions() const; /** * @brief Get the Weight object + * Currently, only the unit tests use this function. * * @param idx Identifier of the weight * @return Weight& Reference to the weight @@ -495,11 +523,11 @@ class LayerNode final : public ml::train::Layer, public GraphNode { NNTR_THROW_IF(!run_context, std::runtime_error) << __func__ << " layer needs to be finalized first!"; if (run_context->weightHasGradient(idx)) { - return Weight(run_context->getWeight(idx), - run_context->getWeightGrad(idx), - run_context->getWeightName(idx)); + return Weight( + run_context->getWeight(idx), run_context->getWeightGrad(idx), + run_context->getWeightFP32(idx), run_context->getWeightName(idx)); } else { - return Weight(run_context->getWeight(idx), Tensor(), + return Weight(run_context->getWeight(idx), Tensor(), Tensor(), run_context->getWeightName(idx)); } } @@ -718,14 +746,17 @@ class LayerNode final : public ml::train::Layer, public GraphNode { * @param file input file stream * @param bool read optimizer variables */ - void read(std::ifstream &file, bool opt_var = false); + void read(std::ifstream &file, bool opt_var = false, + ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN); /** * @brief save layer Weight & Bias data from file * @param file output file stream * @param bool save optimizer variables */ - void save(std::ofstream &file, bool opt_var = false) const; + void + save(std::ofstream &file, bool opt_var = false, + ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN) const; /** * @brief clear optimizer variable to initial state @@ -819,7 +850,8 @@ class LayerNode final : public ml::train::Layer, public GraphNode { void configureRunContext(const std::vector &weights, const std::vector &inputs, const std::vector &outputs, - const std::vector &tensors); + const std::vector &tensors, + float loss_scale); /** * @brief Preset modes for printing summary for the layer @@ -877,6 +909,16 @@ class LayerNode final : public ml::train::Layer, public GraphNode { needs_calc_derivative = nb; } + /** + * @brief Set whether the layer output needs reinitialization for mixed precision + * + * @param nb true if the layer needs to do reinitialization, else false + */ + void reStoreData(bool nb) { + needs_restore_data = nb; + run_context->reStoreData(nb); + } + /** * @brief Set if the layer needs to do calculation of gradients * @@ -898,6 +940,13 @@ class LayerNode final : public ml::train::Layer, public GraphNode { */ bool needsCalcGradient() { return needs_calc_gradient; } + /** + * @brief Check whether the layer needs reinitialization for mixed precision + * + * @return true if the layer needs reinitialization, else false + */ + bool reStoreData() { return needs_restore_data; } + private: /** * @brief Get the Input Layers object @@ -971,6 +1020,11 @@ properties in the context/graph unless intended. 
*/ ExecutionOrder exec_order; /**< order/location of execution for this node in forward and backwarding operations */ + bool needs_restore_data; /**< cache if this layer needs reinitialization + output */ + + std::array data_type; + /** * @brief Get the effective layer managed by this layer node * diff --git a/nntrainer/layers/loss/loss_layer.cpp b/nntrainer/layers/loss/loss_layer.cpp index 40f74717f8..8d18878f49 100644 --- a/nntrainer/layers/loss/loss_layer.cpp +++ b/nntrainer/layers/loss/loss_layer.cpp @@ -22,8 +22,12 @@ void LossLayer::finalize(InitLayerContext &context) { d.setDataType( str_converter::from_string("FP32")); - + context.setOutputDimensions(output_dim); + + is_inplace = true; + if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32) + is_inplace = false; } void LossLayer::updateLoss(RunLayerContext &context, const Tensor &l) { @@ -36,6 +40,13 @@ void LossLayer::updateLoss(RunLayerContext &context, const Tensor &l) { context.setLoss(loss_sum / (float)l.batch()); } +void LossLayer::applyLossScale(RunLayerContext &context, Tensor &ret_deriv) { + + float loss_scale = context.getLossScale(); + if (loss_scale != 1.0) + ret_deriv.multiply_i(loss_scale); +} + /** * @copydoc Layer::setProperty(const std::vector &values) */ diff --git a/nntrainer/layers/loss/loss_layer.h b/nntrainer/layers/loss/loss_layer.h index 00b520f6e6..418777606c 100644 --- a/nntrainer/layers/loss/loss_layer.h +++ b/nntrainer/layers/loss/loss_layer.h @@ -47,6 +47,8 @@ class LossLayer : public Layer { */ virtual bool supportBackwarding() const override { return true; } + bool supportInPlace() const override {return is_inplace;} + /** * @copydoc Layer::requireLabel() */ @@ -60,8 +62,17 @@ class LossLayer : public Layer { */ void updateLoss(RunLayerContext &context, const Tensor &l); + /** + * @brief update return derivative with loss scale + * @param context Run context to update + * @param return_dev Tensor data to calculate + */ + void applyLossScale(RunLayerContext &context, Tensor &l); + Tensor l; /**< loss tensor to store intermediate value to calculate loss value */ + + bool is_inplace; }; } // namespace nntrainer diff --git a/nntrainer/layers/loss/mse_loss_layer.cpp b/nntrainer/layers/loss/mse_loss_layer.cpp index 7f7bd1626f..356acae6f5 100644 --- a/nntrainer/layers/loss/mse_loss_layer.cpp +++ b/nntrainer/layers/loss/mse_loss_layer.cpp @@ -20,7 +20,16 @@ static constexpr size_t SINGLE_INOUT_IDX = 0; void MSELossLayer::forwarding(RunLayerContext &context, bool training) { Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX); - Tensor &y = context.getInput(SINGLE_INOUT_IDX); + + Tensor empty_tensor; + Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? context.getInput(SINGLE_INOUT_IDX) + : empty_tensor; + + if (y.empty()) + y = context.getInput(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); // hidden_ <- y2 - y; if (context.isLabelAvailable(SINGLE_INOUT_IDX)) { @@ -41,9 +50,28 @@ void MSELossLayer::forwarding(RunLayerContext &context, bool training) { } void MSELossLayer::calcDerivative(RunLayerContext &context) { - Tensor &ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + Tensor empty_tensor; + + Tensor &ret_derivative = + context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? 
context.getOutgoingDerivative(SINGLE_INOUT_IDX) + : empty_tensor; + + if (ret_derivative.empty()) + ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); + Tensor empty_tensor1; + Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? context.getInput(SINGLE_INOUT_IDX) + : empty_tensor1; + + if (y.empty()) + y = context.getInput(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); + const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX); - Tensor &y = context.getInput(SINGLE_INOUT_IDX); y.subtract(y2, ret_derivative); float divider = ((float)y.size()) / 2; @@ -51,6 +79,16 @@ void MSELossLayer::calcDerivative(RunLayerContext &context) { throw std::runtime_error( "[MSELossLayer::calcDerivative] Error when calculating loss"); } + + // Loss Scale needs Full precsiion of ret_derivative. Therefore, + // ret_derivateive should be FP32 when applying scale, and after applying it + // need to convert original type for backpropagating. + + LossLayer::applyLossScale(context, ret_derivative); + + if (context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() != + ml::train::TensorDim::DataType::FP32) + context.getOutgoingDerivative(SINGLE_INOUT_IDX).copyData(ret_derivative); } } // namespace nntrainer diff --git a/nntrainer/layers/loss/mse_loss_layer.h b/nntrainer/layers/loss/mse_loss_layer.h index 387e92b3b5..829b921668 100644 --- a/nntrainer/layers/loss/mse_loss_layer.h +++ b/nntrainer/layers/loss/mse_loss_layer.h @@ -51,6 +51,7 @@ class MSELossLayer : public LossLayer { const std::string getType() const override { return MSELossLayer::type; }; inline static const std::string type = "mse"; + }; } // namespace nntrainer diff --git a/nntrainer/layers/lstm.cpp b/nntrainer/layers/lstm.cpp index d5f13a1fc5..591722dc62 100644 --- a/nntrainer/layers/lstm.cpp +++ b/nntrainer/layers/lstm.cpp @@ -463,6 +463,7 @@ void LSTMLayer::finalize(InitLayerContext &context) { // bidirectional ? 2 * unit : unit ] TensorDim::TensorType activation_tensor_type = { context.getFormat(), context.getActivationDataType()}; + TensorDim::TensorType weight_tensor_type = {context.getFormat(), context.getWeightDataType()}; const TensorDim output_dim(batch_size, 1, return_sequences ? 
max_timestep : 1, @@ -510,20 +511,23 @@ void LSTMLayer::finalize(InitLayerContext &context) { // hidden_state_dim : [ batch_size, 1, max_timestep, unit ] const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + activation_tensor_type); + wt_idx[LSTMParams::hidden_state] = context.requestTensor( hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // cell_state_dim : [ batch_size, 1, max_timestep, unit ] const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + activation_tensor_type); + wt_idx[LSTMParams::cell_state] = context.requestTensor( cell_state_dim, "cell_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ] const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit, - weight_tensor_type); + activation_tensor_type); + wt_idx[LSTMParams::ifgo] = context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); @@ -577,20 +581,20 @@ void LSTMLayer::finalize(InitLayerContext &context) { // reverse_hidden_state_dim : [ batch_size, 1, max_timestep, unit ] const TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + activation_tensor_type); wt_idx[LSTMParams::reverse_hidden_state] = context.requestTensor( reverse_hidden_state_dim, "reverse_hidden_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // reverse_cell_state_dim : [ batch_size, 1, max_timestep, unit ] const TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + activation_tensor_type); wt_idx[LSTMParams::reverse_cell_state] = context.requestTensor( reverse_cell_state_dim, "reverse_cell_state", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ] const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep, - NUM_GATE * unit, weight_tensor_type); + NUM_GATE * unit, activation_tensor_type); wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor( reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN); @@ -599,7 +603,7 @@ void LSTMLayer::finalize(InitLayerContext &context) { if (dropout_rate > epsilon) { // dropout_mask_dim = [ batch, 1, time_iteration, unit ] const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit, - weight_tensor_type); + activation_tensor_type); wt_idx[LSTMParams::dropout_mask] = context.requestTensor( dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); diff --git a/nntrainer/layers/pooling2d_layer.cpp b/nntrainer/layers/pooling2d_layer.cpp index a68e42e8d0..676a0a6128 100644 --- a/nntrainer/layers/pooling2d_layer.cpp +++ b/nntrainer/layers/pooling2d_layer.cpp @@ -6,6 +6,8 @@ * @date 12 June 2020 * @see https://github.com/nnstreamer/nntrainer * @author Jijoong Moon + * @author Donghak Park + * @author Jiho Chu * @bug No known bugs except for NYI items * @brief This is 2 Dimensional Pooling Layer Class for Neural Network * @@ -26,6 +28,13 @@ namespace nntrainer { static constexpr size_t SINGLE_INOUT_IDX = 0; +/** + * @brief Help function for Pooling Handler + */ +template struct PoolFunc { + typedef std::function Type; +}; + Pooling2DLayer::Pooling2DLayer( const std::array &padding_) : Layer(), @@ -73,7 +82,9 @@ void 
Pooling2DLayer::finalize(InitLayerContext &context) { NNTR_THROW_IF(pt + pb + pl + pr != 0, std::invalid_argument) << "[Pooling2D] global_max, global_average does not accept padding"; - NNTR_THROW_IF(stride[0] != 1 || stride[1] != 1, std::invalid_argument) + NNTR_THROW_IF(static_cast(stride[0]) != 1 || + static_cast(stride[1]) != 1, + std::invalid_argument) << "[Pooling2D] global_max, global_average does not accept stride"; } @@ -96,6 +107,7 @@ void Pooling2DLayer::finalize(InitLayerContext &context) { out_dim.channel(in_dim.channel()); out_dim.height((eff_in_height - pool_size[0]) / stride[0] + 1); out_dim.width((eff_in_width - pool_size[1]) / stride[1] + 1); + out_dim.setDataType(in_dim.getDataType()); context.setOutputDimensions({out_dim}); /** @@ -111,13 +123,17 @@ void Pooling2DLayer::finalize(InitLayerContext &context) { * // clang-format on */ if (pooling_type == props::PoolingTypeInfo::Enum::global_max) { + auto helper_dim = in_dim; + helper_dim.setDataType(ml::train::TensorDim::DataType::FP32); pool_helper_idx = - context.requestTensor(in_dim, "helper_idx", Tensor::Initializer::NONE, + context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); - pool_helper_size.resize(in_dim.batch() * in_dim.channel()); + pool_helper_size.resize(helper_dim.batch() * helper_dim.channel()); } else { + auto helper_dim = out_dim; + helper_dim.setDataType(ml::train::TensorDim::DataType::FP32); pool_helper_idx = - context.requestTensor(out_dim, "helper_idx", Tensor::Initializer::NONE, + context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN); } } @@ -172,15 +188,12 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { unsigned int J, K; result.setZero(); - float *result_data = result.getData(); unsigned int out_map_size = deriv.height() * deriv.width(); unsigned int in_map_size = height * width; - - switch (pooling_type) { - case props::PoolingTypeInfo::Enum::max: { + auto apply_max = [&](T *result_data) { const int *iter = pool_helper.getData(); - const float *deriv_data = deriv.getData(); + const T *deriv_data = deriv.getData(); for (unsigned int b = 0; b < batch; ++b) { for (unsigned int c = 0; c < channel; ++c) { for (unsigned int i = 0; i < out_map_size; ++i) { @@ -195,9 +208,9 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { result_data += in_map_size; } } - } break; - case props::PoolingTypeInfo::Enum::global_average: - case props::PoolingTypeInfo::Enum::average: { + }; + + auto apply_average = [&](T *result_data) { int height_stride_end = height - p_height + pt; int width_stride_end = width - p_width + pl; const int *iter = pool_helper.getData(); @@ -207,7 +220,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { for (int j = -pt; j <= height_stride_end; j += stride[0]) { K = 0; for (int k = -pl; k <= width_stride_end; k += stride[1]) { - float del = deriv.getValue(b, i, J, K) / *iter; + T del = deriv.getValue(b, i, J, K) / *iter; int patch_height_end = std::min(static_cast(j + p_height), height); int patch_width_end = @@ -217,7 +230,7 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { for (int h = start_h; h < patch_height_end; ++h) { for (int w = start_w; w < patch_width_end; ++w) { result.setValue(b, i, h, w, - result.getValue(b, i, h, w) + del); + result.getValue(b, i, h, w) + del); } } iter++; @@ -227,26 +240,65 @@ void Pooling2DLayer::calcDerivative(RunLayerContext &context) { } } } - } break; - case 
props::PoolingTypeInfo::Enum::global_max: { - const float *deriv_data = deriv.getData(); + }; + + auto apply_global_max = [&](T *result_data) { + const T *deriv_data = deriv.getData(); for (unsigned int b = 0; b < batch; b++) { for (unsigned int c = 0; c < channel; c++) { const int *iter = pool_helper.getData() + pool_helper.getIndex(b, c, 0, 0); unsigned int helper_size = pool_helper_size[b * channel + c]; - float der = *deriv_data / helper_size; + T der = *deriv_data / static_cast(helper_size); for (unsigned int idx = 0; idx < helper_size; idx++) result_data[iter[idx]] += der; - deriv_data++; result_data += in_map_size; } } - } break; - default: - throw std::runtime_error("Error: Unknown Pooling Type"); + }; + + auto in_data_type = in_dim.getDataType(); + + if (in_data_type == ml::train::TensorDim::DataType::FP32) { + switch (pooling_type) { + case props::PoolingTypeInfo::Enum::max: + apply_max(result.getData()); + break; + case props::PoolingTypeInfo::Enum::global_average: + case props::PoolingTypeInfo::Enum::average: + apply_average(result.getData()); + break; + case props::PoolingTypeInfo::Enum::global_max: + apply_global_max(result.getData()); + break; + default: + throw std::runtime_error("Error: Unknown Pooling Type"); + break; + } + } +#ifdef ENABLE_FP16 + else if (in_data_type == ml::train::TensorDim::DataType::FP16) { + + switch (pooling_type) { + case props::PoolingTypeInfo::Enum::max: + apply_max(result.getData<_FP16>()); + break; + case props::PoolingTypeInfo::Enum::global_average: + case props::PoolingTypeInfo::Enum::average: + apply_average(result.getData<_FP16>()); + break; + case props::PoolingTypeInfo::Enum::global_max: + apply_global_max(result.getData<_FP16>()); + break; + default: + throw std::runtime_error("Error: Unknown Pooling Type"); + } + } +#endif + else { + throw std::runtime_error("Unsupported datatype"); } } @@ -290,124 +342,167 @@ void Pooling2DLayer::pooling2d(Tensor &in, bool training, Tensor &output, * @param start_w (width index pointing the start of the patch) * @return result value of pooling */ - std::function pool_fn; + PoolFunc::Type pool_fn_fp32; +#ifdef ENABLE_FP16 + PoolFunc<_FP16>::Type pool_fn_fp16; +#endif unsigned int max_idx_count = 0; - switch (pooling_type) { - case props::PoolingTypeInfo::Enum::max: { - pool_fn = [&](const float *in_data, int channel_idx, int start_h, - int start_w) { - int end_h = start_h + patch_height; - int end_w = start_w + patch_width; - - float max_val = std::numeric_limits::lowest(); - - int cur_max_idx = -1; - int eff_end_h = std::min(end_h, in_height); - int eff_end_w = std::min(end_w, in_width); - start_w = std::max(0, start_w); - for (int h = std::max(0, start_h); h < eff_end_h; ++h) { - for (int w = start_w; w < eff_end_w; ++w) { - int cur_idx = h * in_width + w; - float val = in_data[cur_idx]; - if (max_val < val) { - max_val = val; - if (training) { - cur_max_idx = cur_idx; - } + + auto pool_fn_max = [&](const T *in_data, int channel_idx, + int start_h, int start_w) { + int end_h = start_h + patch_height; + int end_w = start_w + patch_width; + + T max_val = std::numeric_limits::lowest(); + + int cur_max_idx = -1; + int eff_end_h = std::min(end_h, in_height); + int eff_end_w = std::min(end_w, in_width); + start_w = std::max(0, start_w); + for (int h = std::max(0, start_h); h < eff_end_h; ++h) { + for (int w = start_w; w < eff_end_w; ++w) { + int cur_idx = h * in_width + w; + T val = in_data[cur_idx]; + if (max_val < val) { + max_val = val; + if (training) { + cur_max_idx = cur_idx; } } } + } - if (training) 
{ - pool_helper.setValueInt(max_idx_count++, cur_max_idx); - } + if (training) { + pool_helper.setValueInt(max_idx_count++, cur_max_idx); + } - return max_val; - }; - break; - } - case props::PoolingTypeInfo::Enum::global_max: { - pool_fn = [&, this](const float *in_data, int channel_idx, int start_h, - int start_w) { - int end_h = start_h + patch_height; - int end_w = start_w + patch_width; - - float max_val = std::numeric_limits::lowest(); - int *helper_data = pool_helper.getData(); - helper_data += channel_idx * in_height * in_width; - - for (int h = start_h; h < end_h; ++h) { - for (int w = start_w; w < end_w; ++w) { - int cur_idx = h * in_width + w; - float val = in_data[cur_idx]; - if (max_val < val) { - max_val = val; - max_idx_count = 0; - } + return max_val; + }; - if (training && max_val == val) { - *(helper_data + max_idx_count++) = cur_idx; - } + auto pool_fn_global_max = [&, this](const T *in_data, + int channel_idx, int start_h, + int start_w) { + int end_h = start_h + patch_height; + int end_w = start_w + patch_width; + + T max_val = std::numeric_limits::lowest(); + int *helper_data = pool_helper.getData(); + helper_data += channel_idx * in_height * in_width; + + for (int h = start_h; h < end_h; ++h) { + for (int w = start_w; w < end_w; ++w) { + int cur_idx = h * in_width + w; + T val = in_data[cur_idx]; + if (max_val < val) { + max_val = val; + max_idx_count = 0; } - } - pool_helper_size[batch_idx * in.channel() + channel_idx] = max_idx_count; - return max_val; - }; - break; - } - case props::PoolingTypeInfo::Enum::global_average: - case props::PoolingTypeInfo::Enum::average: { - pool_fn = [&](const float *in_data, int channel_idx, int start_h, - int start_w) { - int end_h = start_h + patch_height; - int end_w = start_w + patch_width; - float total = 0.0f; - - int eff_end_h = std::min(end_h, in_height); - int eff_end_w = std::min(end_w, in_width); - int eff_start_h = std::max(0, start_h); - int eff_start_w = std::max(0, start_w); - - int cnt = (eff_end_h - eff_start_h) * (eff_end_w - eff_start_w); - for (int h = eff_start_h; h < eff_end_h; ++h) { - for (int w = eff_start_w; w < eff_end_w; ++w) { - float val = in_data[h * in_width + w]; - total += val; + if (training && max_val == val) { + *(helper_data + max_idx_count++) = cur_idx; } } + } - if (training) { - pool_helper.setValueInt(max_idx_count++, cnt); + pool_helper_size[batch_idx * in.channel() + channel_idx] = max_idx_count; + return max_val; + }; + + auto pool_fn_average = [&](const T *in_data, int channel_idx, + int start_h, int start_w) { + int end_h = start_h + patch_height; + int end_w = start_w + patch_width; + T total = static_cast(0.0f); + + int eff_end_h = std::min(end_h, in_height); + int eff_end_w = std::min(end_w, in_width); + int eff_start_h = std::max(0, start_h); + int eff_start_w = std::max(0, start_w); + + int cnt = (eff_end_h - eff_start_h) * (eff_end_w - eff_start_w); + for (int h = eff_start_h; h < eff_end_h; ++h) { + for (int w = eff_start_w; w < eff_end_w; ++w) { + T val = in_data[h * in_width + w]; + total += val; } - return total / cnt; - }; + } + + if (training) { + pool_helper.setValueInt(max_idx_count++, cnt); + } + return total / cnt; + }; + + switch (pooling_type) { + case props::PoolingTypeInfo::Enum::max: + pool_fn_fp32 = pool_fn_max; +#ifdef ENABLE_FP16 + pool_fn_fp16 = pool_fn_max; +#endif + break; + case props::PoolingTypeInfo::Enum::global_max: + pool_fn_fp32 = pool_fn_global_max; +#ifdef ENABLE_FP16 + pool_fn_fp16 = pool_fn_global_max; +#endif + break; + case 
props::PoolingTypeInfo::Enum::global_average: + case props::PoolingTypeInfo::Enum::average: + pool_fn_fp32 = pool_fn_average; +#ifdef ENABLE_FP16 + pool_fn_fp16 = pool_fn_average; +#endif break; - } case props::PoolingTypeInfo::Enum::unknown: default: throw std::invalid_argument("unknown pooling type given"); break; } - const float *in_data = in.getData(); - float *out_data = output.getData(); - - unsigned int map_size = in_height * in_width; - - int height_stride_end = height - patch_height - pt; - int width_stride_end = width - patch_width - pl; - for (unsigned int i = 0; i < channel; ++i) { - const float *in_data_channel_sliced = in_data + i * map_size; - for (int j = -pt; j <= height_stride_end; j += stride[0]) { - for (int k = -pl; k <= width_stride_end; k += stride[1]) { - float pool_value = pool_fn(in_data_channel_sliced, i, j, k); - *out_data = pool_value; - out_data++; + if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { + const float *in_data = in.getData(); + float *out_data = output.getData(); + + unsigned int map_size = in_height * in_width; + + int height_stride_end = height - patch_height - pt; + int width_stride_end = width - patch_width - pl; + for (unsigned int i = 0; i < channel; ++i) { + const float *in_data_channel_sliced = in_data + i * map_size; + for (int j = -pt; j <= height_stride_end; j += stride[0]) { + for (int k = -pl; k <= width_stride_end; k += stride[1]) { + float pool_value = pool_fn_fp32(in_data_channel_sliced, i, j, k); + *out_data = pool_value; + out_data++; + } + } + } + } +#ifdef ENABLE_FP16 + else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { + const _FP16 *in_data = in.getData<_FP16>(); + _FP16 *out_data = output.getData<_FP16>(); + + unsigned int map_size = in_height * in_width; + + int height_stride_end = height - patch_height - pt; + int width_stride_end = width - patch_width - pl; + for (unsigned int i = 0; i < channel; ++i) { + const _FP16 *in_data_channel_sliced = in_data + i * map_size; + for (int j = -pt; j <= height_stride_end; j += stride[0]) { + for (int k = -pl; k <= width_stride_end; k += stride[1]) { + _FP16 pool_value = pool_fn_fp16(in_data_channel_sliced, i, j, k); + *out_data = pool_value; + out_data++; + } } } } +#endif + else { + throw std::runtime_error("Not supported datatype"); + } } void Pooling2DLayer::setBatch(RunLayerContext &context, unsigned int batch) { diff --git a/nntrainer/layers/reshape_layer.cpp b/nntrainer/layers/reshape_layer.cpp index 0f82d84f3a..af4c254475 100644 --- a/nntrainer/layers/reshape_layer.cpp +++ b/nntrainer/layers/reshape_layer.cpp @@ -42,7 +42,7 @@ void ReshapeLayer::finalize(InitLayerContext &context) { } out_dim.batch(in_dim.batch()); - + out_dim.setDataType(context.getActivationDataType()); context.setOutputDimensions({out_dim}); } diff --git a/nntrainer/layers/time_dist.cpp b/nntrainer/layers/time_dist.cpp index 80451416df..779010065a 100644 --- a/nntrainer/layers/time_dist.cpp +++ b/nntrainer/layers/time_dist.cpp @@ -256,8 +256,8 @@ void TimeDistLayer::forwarding(RunLayerContext &context, bool training) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->forwarding(dist_context, training); } @@ -303,8 +303,8 @@ void TimeDistLayer::calcDerivative(RunLayerContext &context) { RunLayerContext 
dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->calcDerivative(dist_context); } @@ -354,8 +354,8 @@ void TimeDistLayer::calcGradient(RunLayerContext &context) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->calcGradient(dist_context); } @@ -396,8 +396,8 @@ void TimeDistLayer::setBatch(RunLayerContext &context, unsigned int batch) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->setBatch(dist_context, batch); diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h index 3776afefca..3435d18e96 100644 --- a/nntrainer/models/model_common_properties.h +++ b/nntrainer/models/model_common_properties.h @@ -217,7 +217,7 @@ class ModelTensorDataType final : public EnumProperty { */ class LossScale : public Property { public: - LossScale(float value = 0.0f); + LossScale(float value = 1.0f); static constexpr const char *key = "loss_scale"; /**< unique key to access */ using prop_tag = float_prop_tag; /**< property type */ }; diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index d0e542825f..de3a0953a4 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -79,7 +79,8 @@ NeuralNetwork::NeuralNetwork() : data_buffers({nullptr, nullptr, nullptr}), initialized(false), compiled(false), - loadedFromConfig(false) { + loadedFromConfig(false), + exec_mode(ExecutionMode::TRAIN) { app_context = AppContext(AppContext::Global()); } @@ -145,7 +146,10 @@ void NeuralNetwork::setTrainConfig(const std::vector &values) { << " of first element: " << left_props.front(); } -int NeuralNetwork::compile() { +int NeuralNetwork::compile(ExecutionMode mode) { + + exec_mode = mode; + std::string loss_type = std::get(model_props).empty() ? std::string() : std::get(model_props); @@ -181,7 +185,7 @@ int NeuralNetwork::compile() { const std::string tensor_type = to_string(std::get(model_flex_props)); - model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead, + model_graph = NetworkGraph(memory_swap, mode, memory_swap_path, lookahead, tensor_format, tensor_type); model_graph.setMemoryOptimizations( @@ -412,9 +416,21 @@ void NeuralNetwork::backwarding(int iteration, NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!"; #endif - std::function, int)> backwarding_op = + std::function, bool)> forwarding_op = [this, stop_cb, userdata](std::shared_ptr node, - int iteration) -> void { + bool training) -> void { + (void)this; + PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName()); + + auto f = std::get<0>(node->getExecutionOrder()); + model_graph.flushCacheExcept(f); + + node->forwarding(training); + }; + + std::function, int)> backwarding_op = + [this, stop_cb, userdata](std::shared_ptr node, + int iteration) -> bool { /** * Do not change this order: * 1. 
calcGradient @@ -448,19 +464,29 @@ void NeuralNetwork::backwarding(int iteration, /** If gradient must be applied and its not gradient mode, calculate * gradient */ - if (!dynamic_training_opt.isGradientMode() && apply_gradient) + if (!dynamic_training_opt.isGradientMode() && apply_gradient) { node->calcGradient(); + + RunLayerContext &rc = node->getRunContext(); + if (rc.isMixedPrecision()) { + for (auto w : rc.getWeights()) { + if (!w->getGradientRef().isValid()) + return false; + } + } + } } model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder())); PROFILE_MEM_ANNOTATE("CalcDerivative: " + node->getName()); if (stop_cb(userdata)) { - return; + return true; } - if (node->needsCalcDerivative()) + if (node->needsCalcDerivative()) { node->calcDerivative(); + } model_graph.flushCacheExcept(std::get<3>(node->getExecutionOrder())); PROFILE_MEM_ANNOTATE("ApplyGradient: " + node->getName()); @@ -476,9 +502,10 @@ void NeuralNetwork::backwarding(int iteration, opt_->applyGradient(opt_context); }); } + return true; }; - std::function apply_grad_clip_op = + std::function lazy_apply_grad_op = [opt_ = opt.get()](Weight &w, int iteration) -> void { w.calcRegularizationGradient(); w.calcWeightDecayGradient(); @@ -487,8 +514,13 @@ void NeuralNetwork::backwarding(int iteration, opt_->applyGradient(opt_context); }; - model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op, - stop_cb, userdata); + // return false if the gradient is not valid + bool ret = false; + + while (!ret) { + ret = model_graph.backwarding(iteration, forwarding_op, backwarding_op, + lazy_apply_grad_op, stop_cb, userdata); + } } void NeuralNetwork::save(const std::string &file_path, @@ -555,7 +587,7 @@ void NeuralNetwork::load(const std::string &file_path, auto model_file = checkedOpenStream( file_path, std::ios::in | std::ios::binary); for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) { - (*iter)->read(model_file); + (*iter)->read(model_file, false, exec_mode); } try { /// this is assuming that the failure is allowed at the end of the file @@ -567,7 +599,7 @@ void NeuralNetwork::load(const std::string &file_path, if (istrequal(opt_type, "adam")) { for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) { - (*iter)->read(model_file, true); + (*iter)->read(model_file, true, exec_mode); } } } @@ -999,7 +1031,7 @@ int NeuralNetwork::train_run( break; } auto &iteration = iter_view.get(); - if (iteration.batch() != batch_size) { + if (iteration.batch() != static_cast(batch_size)) { /// @todo support partial batch continue; } diff --git a/nntrainer/models/neuralnet.h b/nntrainer/models/neuralnet.h index da1571a328..bda7bcbab7 100644 --- a/nntrainer/models/neuralnet.h +++ b/nntrainer/models/neuralnet.h @@ -162,7 +162,7 @@ class NeuralNetwork : public ml::train::Model { * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - int compile() override; + int compile(ExecutionMode mode = ExecutionMode::TRAIN) override; /** * @brief set Property of Network @@ -682,6 +682,8 @@ s * @retval shared_ptr DynamicTrainingOptimization dynamic_training_opt; /**< Dynamic fine-tuning optimization mode. 
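The backwarding path above makes gradient overflow a retriable condition: backwarding_op returns false as soon as a mixed-precision weight holds an invalid (NaN/Inf) gradient, and the caller re-runs the iteration in a while (!ret) loop. A compact sketch of that skip-and-retry pattern follows; runIteration and the halving of the loss scale are illustrative assumptions (the hunk itself only retries), not nntrainer API:

#include <cmath>
#include <vector>

namespace sketch {

// Stand-in for one forward + backward pass producing gradients.
std::vector<float> runIteration() { return {0.1f, -0.2f, 0.3f}; }

bool allFinite(const std::vector<float> &g) {
  for (float v : g)
    if (!std::isfinite(v))
      return false;
  return true;
}

// Keep re-running the iteration until every gradient is finite, mirroring
// `while (!ret) ret = model_graph.backwarding(...);` above.
void trainOneStep(float &loss_scale) {
  bool ok = false;
  while (!ok) {
    ok = allFinite(runIteration());
    if (!ok)
      loss_scale *= 0.5f; // typical dynamic-loss-scaling reaction (assumption)
  }
  // gradients are valid here; apply them via the optimizer
}

} // namespace sketch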
supported modes are "max" and "norm" */ + ExecutionMode exec_mode; + /** * @brief save model in ini * diff --git a/nntrainer/optimizers/adam.cpp b/nntrainer/optimizers/adam.cpp index 18c0a0fcc1..f7189dda7e 100644 --- a/nntrainer/optimizers/adam.cpp +++ b/nntrainer/optimizers/adam.cpp @@ -36,7 +36,15 @@ Adam::~Adam() {} enum AdamParams { wm, wv }; std::vector Adam::getOptimizerVariableDim(const TensorDim &dim) { - return {dim, dim}; + /** + * @note We assume the optimizer parameters should be full precsion to + * maintain the accuracy even in mixed precision training. + */ + TensorDim wm_dim(dim); + TensorDim wv_dim(dim); + wm_dim.setDataType(ml::train::TensorDim::DataType::FP32); + wv_dim.setDataType(ml::train::TensorDim::DataType::FP32); + return {wm_dim, wv_dim}; } void Adam::exportTo(Exporter &exporter, @@ -64,7 +72,17 @@ double Adam::getUpdatedLearningRate(unsigned int iteration, double ll) const { } void Adam::applyGradient(RunOptimizerContext &context) { - Tensor &x_grad = context.getGradient(); + Tensor empty_tensor; + + Tensor &x_grad = + context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32 + ? context.getGradient() + : empty_tensor; + + if (x_grad.empty()) { + x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32); + context.applyLossScale(x_grad); + } auto &beta1 = std::get(adam_props).get(); auto &beta2 = std::get(adam_props).get(); @@ -91,7 +109,7 @@ void Adam::applyGradient(RunOptimizerContext &context) { denom.add_i(epsilon); wm.divide(denom, x_grad); - context.applyGradient(context.getLearningRate() / biasCorrection1); + context.applyGradient(context.getLearningRate() / biasCorrection1, x_grad); } else { std::function sqrtEps = [epsilon](double f) { @@ -100,8 +118,9 @@ void Adam::applyGradient(RunOptimizerContext &context) { x_grad = wv.apply(sqrtEps, x_grad); x_grad.multiply_i(wm); - context.applyGradient(getUpdatedLearningRate(context.getIteration(), - context.getLearningRate())); + context.applyGradient( + getUpdatedLearningRate(context.getIteration(), context.getLearningRate()), + x_grad); } } diff --git a/nntrainer/optimizers/optimizer_context.cpp b/nntrainer/optimizers/optimizer_context.cpp index da4cd1f7e9..8380ad6613 100644 --- a/nntrainer/optimizers/optimizer_context.cpp +++ b/nntrainer/optimizers/optimizer_context.cpp @@ -42,4 +42,24 @@ Tensor &RunOptimizerContext::getOptimizerVariable(unsigned int idx) const { void RunOptimizerContext::applyGradient(double lr) const { weight->applyGradient(lr); } + +/** + * @brief Apply the gradient with the given learning rate and gradient + */ +void RunOptimizerContext::applyGradient(double lr, Tensor &updated_grad) const { + weight->applyGradient(lr, updated_grad); +} + +/** + * @brief Apply loss scale to gradient (full precision) + */ +void RunOptimizerContext::applyLossScale(Tensor &fp32_grad) { + if (!weight->isMixedPrecision()) + return; + if (fp32_grad.getDataType() != ml::train::TensorDim::DataType::FP32) + throw std::invalid_argument( + "gradient should be fullprecsion to maintain accuracy"); + float loss_scale = weight->getLossScale(); + fp32_grad.divide_i(loss_scale); +} } // namespace nntrainer diff --git a/nntrainer/optimizers/optimizer_context.h b/nntrainer/optimizers/optimizer_context.h index 62f9e0945d..27f028fc52 100644 --- a/nntrainer/optimizers/optimizer_context.h +++ b/nntrainer/optimizers/optimizer_context.h @@ -35,9 +35,7 @@ class RunOptimizerContext { * */ RunOptimizerContext(Weight *w = nullptr, size_t iter = 0, double lr = 0.0) : - weight(w), - iteration(iter), - 
learning_rate(lr) {} + weight(w), iteration(iter), learning_rate(lr) {} /** * @brief Get the Weight tensor object @@ -75,6 +73,16 @@ class RunOptimizerContext { */ void applyGradient(double lr) const; + /** + * @brief Apply the gradient with the given learning rate and updated + * gradient + * + * @param lr learning rate + * @param updated_grad gradient tensor which is updated. (usually it could be + * fp32) + */ + void applyGradient(double lr, Tensor &updated_grad) const; + /** * @brief Get the current iteration value * @@ -89,6 +97,11 @@ class RunOptimizerContext { */ double getLearningRate() const { return learning_rate; } + /** + * @brief Apply loss scale to gradient (full precision) + */ + void applyLossScale(Tensor &fp32_grad); + private: Weight *weight; /**< weights for the optimizer */ size_t iteration; /**< iteration number */ diff --git a/nntrainer/optimizers/sgd.cpp b/nntrainer/optimizers/sgd.cpp index 8b0078e9e6..e4b2209a57 100644 --- a/nntrainer/optimizers/sgd.cpp +++ b/nntrainer/optimizers/sgd.cpp @@ -16,7 +16,20 @@ namespace nntrainer { void SGD::applyGradient(RunOptimizerContext &context) { - context.applyGradient(context.getLearningRate()); + // @todo This could go inside the context. + Tensor empty_tensor; + + Tensor &x_grad = + context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32 + ? context.getGradient() + : empty_tensor; + + if (x_grad.empty()) { + x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32); + context.applyLossScale(x_grad); + } + + context.applyGradient(context.getLearningRate(), x_grad); } } // namespace nntrainer diff --git a/nntrainer/tensor/blas_avx.cpp b/nntrainer/tensor/blas_avx.cpp index ce59583d6f..411dbcbb5d 100644 --- a/nntrainer/tensor/blas_avx.cpp +++ b/nntrainer/tensor/blas_avx.cpp @@ -20,6 +20,7 @@ namespace nntrainer::avx { +#ifdef ENABLE_FP16 void vcvt_f16_f32(size_t N, const void *input, float *output) { assert(N != 0); assert(input != NULL); @@ -114,4 +115,163 @@ void vcvt_f32_f16(size_t N, const float *input, void *output) { } } +bool isValid(const size_t N, const _Float16 *input) { + assert(N != 0); + assert(input != NULL); + + int temp = 0; + size_t idx = 0; + + const __m256 SIGN_MASK = _mm256_set1_ps(-0.0); + const __m256 INF = _mm256_set1_ps(std::numeric_limits::infinity()); + + // 16 single-precision check : ( X != X ) + for (; N - idx >= 16; idx += 16) { + __m256 vec0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)input)); + __m256 vec1 = + _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(input + 8))); + + input += 16; + + // check NaN in vec0 + __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + if (temp) + return false; + + // check infinity in vec0 + vec0 = _mm256_andnot_ps(SIGN_MASK, vec0); + vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec0); + if (temp) + return false; + + // check NaN in vec1 + __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res1); + + if (temp) + return false; + + // check infinity in vec1 + vec1 = _mm256_andnot_ps(SIGN_MASK, vec1); + vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec1); + + if (temp) + return false; + } + + // 8 single-precision check : ( X != X ) + for (; N - idx >= 8; idx += 8) { + __m256 vec = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)input)); + input += 8; + __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + + if (temp) + return false; 
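// The SGD/Adam hunks above clone a non-FP32 gradient to full precision,
// divide it by the loss scale (RunOptimizerContext::applyLossScale), and only
// then hand it to applyGradient(lr, x_grad). A minimal sketch of that
// unscale-then-apply flow; std::vector<float> stands in for Tensor and
// applyToWeight is an assumed helper, not nntrainer API.
#include <cstddef>
#include <vector>

namespace sketch {

void applyToWeight(std::vector<float> &weight_fp32,
                   const std::vector<float> &grad_fp32, float lr) {
  for (std::size_t i = 0; i < weight_fp32.size(); ++i)
    weight_fp32[i] -= lr * grad_fp32[i];
}

void applyGradient(std::vector<float> &weight_fp32, std::vector<float> grad,
                   float loss_scale, float lr) {
  // The loss layer multiplied the derivative by loss_scale, so undo that in
  // full precision before the update.
  for (float &g : grad)
    g /= loss_scale;
  applyToWeight(weight_fp32, grad, lr);
}

} // namespace sketch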
+ + // check infinity in vec1 + vec = _mm256_andnot_ps(SIGN_MASK, vec); + vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec); + + if (temp) + return false; + } + + // remain check : ( X != X || X == Inf ) + while (idx < N) { + if (*input != *input || *input == std::numeric_limits::infinity()) { + return false; + } + ++input; + ++idx; + } + + return true; +} +#endif + +bool isValid(const size_t N, const float *input) { + assert(N != 0); + assert(input != NULL); + + int temp = 0; + size_t idx = 0; + + const __m256 SIGN_MASK = _mm256_set1_ps(-0.0); + const __m256 INF = _mm256_set1_ps(std::numeric_limits::infinity()); + + // 16 single-precision check : ( X != X ) + for (; N - idx >= 16; idx += 16) { + __m256 vec0 = _mm256_loadu_ps(input); + __m256 vec1 = _mm256_loadu_ps(input + 8); + input += 16; + __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + + if (temp) + return false; + + // check infinity in vec0 + vec0 = _mm256_andnot_ps(SIGN_MASK, vec0); + vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec0); + if (temp) + return false; + + __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res1); + + if (temp) + return false; + + // check infinity in vec1 + vec1 = _mm256_andnot_ps(SIGN_MASK, vec1); + vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec1); + + if (temp) + return false; + } + + // 8 single-precision check : ( X != X ) + for (; N - idx >= 8; idx += 8) { + __m256 vec = _mm256_loadu_ps(input); + input += 8; + __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + + if (temp) + return false; + + // check infinity in vec + vec = _mm256_andnot_ps(SIGN_MASK, vec); + vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec); + + if (temp) + return false; + } + + // remain check : ( X != X ) + while (idx < N) { + if (*input != *input || *input == std::numeric_limits::infinity()) { + return false; + } + ++input; + ++idx; + } + + return true; +} + } // namespace nntrainer::avx diff --git a/nntrainer/tensor/blas_avx.h b/nntrainer/tensor/blas_avx.h index ab1270a208..5eabcbdb2c 100644 --- a/nntrainer/tensor/blas_avx.h +++ b/nntrainer/tensor/blas_avx.h @@ -20,6 +20,7 @@ namespace nntrainer::avx { +#ifdef ENABLE_FP16 /** * @brief Converts half-precision floating point values to single-precision * floating point values. 
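The isValid kernels above lean on two scalar identities: NaN is the only value for which x != x (hence the _CMP_NEQ_UQ comparisons), and clearing the sign bit with an andnot against -0.0 maps -Inf onto +Inf so a single compare against +Inf catches both infinities. A tiny scalar self-check of those identities, independent of AVX:

#include <cassert>
#include <cstring>
#include <limits>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();

  // NaN compares unequal to itself; ordinary values do not.
  assert(nan != nan);
  assert(!(1.0f != 1.0f));

  // Clearing the sign bit (the scalar analogue of _mm256_andnot_ps with
  // SIGN_MASK = -0.0f) turns -Inf into +Inf.
  float minus_inf = -inf;
  unsigned int bits;
  std::memcpy(&bits, &minus_inf, sizeof(bits));
  bits &= 0x7fffffffu;
  float abs_val;
  std::memcpy(&abs_val, &bits, sizeof(abs_val));
  assert(abs_val == inf);

  return 0;
}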
@@ -40,6 +41,25 @@ void vcvt_f16_f32(size_t N, const void *input, float *output); */ void vcvt_f32_f16(size_t N, const float *input, void *output); +/** + * @brief check if the X has NaN value + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] X half-precision * for Vector X + * @param[out] false if it has NaN or inf + */ +bool isValid(const size_t N, const _Float16 *X); +#endif + +/** + * @brief check if the X has NaN value + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] X float * for Vector X + * @param[out] false if it has NaN or inf + */ +bool isValid(const size_t N, const float *X); + } // namespace nntrainer::avx #endif /* __cplusplus */ diff --git a/nntrainer/tensor/blas_interface.cpp b/nntrainer/tensor/blas_interface.cpp index e04c1ce499..224e433d42 100644 --- a/nntrainer/tensor/blas_interface.cpp +++ b/nntrainer/tensor/blas_interface.cpp @@ -864,7 +864,10 @@ void scopy(const unsigned int N, const float *X, const int incX, float *Y, #ifdef BLAS_NUM_THREADS openblas_set_num_threads(BLAS_NUM_THREADS); #endif - cblas_scopy(N, X, incX, Y, incY); + // cblas_scopy(N, (float*)(X), incX, (float*)(Y), incY); + // replace cblas scopy with raw temporary. + for (unsigned int i = 0; i < N; ++i) + Y[i * incY] = X[i * incX]; #else scopy_raw(N, X, incX, Y, incY); #endif @@ -1060,6 +1063,16 @@ static void ele_div_fallback(const unsigned int N, const float *X, } } +static bool is_valid_fallback(const size_t N, const float *X) { + for (size_t i = 0; i < N; ++i) { + if (*X != *X || *X == std::numeric_limits::infinity()) + return false; + ++X; + } + + return true; +} + void ele_mul(const unsigned int N, const float *X, const float *Y, float *Z, float alpha, float beta, unsigned int i_stride, unsigned int o_stride) { @@ -1112,4 +1125,30 @@ void ele_div(const unsigned int N, const float *X, const float *Y, float *Z, ele_div_fallback(N, X, Y, Z, alpha, beta, i_stride, o_stride); } +bool is_valid(const size_t N, ml::train::TensorDim::DataType d_type, + const void *X) { + if (d_type == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *vec = (const _FP16 *)X; +#ifdef USE_NEON + return nntrainer::neon::isValid(N, vec); +#elif defined(USE_AVX) + return nntrainer::avx::isValid(N, vec); +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif +#endif + } else if (d_type == ml::train::TensorDim::DataType::FP32) { + const float *vec = (const float *)X; +#ifdef USE_NEON + return nntrainer::neon::isValid(N, vec); +#elif defined(USE_AVX) + return nntrainer::avx::isValid(N, vec); +#endif + + return is_valid_fallback(N, vec); + } + return false; +} + } // namespace nntrainer diff --git a/nntrainer/tensor/blas_interface.h b/nntrainer/tensor/blas_interface.h index 69cdda01f9..d270a07919 100644 --- a/nntrainer/tensor/blas_interface.h +++ b/nntrainer/tensor/blas_interface.h @@ -492,6 +492,16 @@ void ele_sub(const unsigned N, const float *X, const float *Y, float *Z, void ele_div(const unsigned N, const float *X, const float *Y, float *Z, float alpha = 1.f, float beta = 0.f, unsigned int i_stride = 1, unsigned int o_stride = 1); + +/** + * @brief check if X array has NaN or inf + * @param[in] N length of the vector + * @param[in] X float/fp16 * for Vector X + * @param[out] bool false if not valide else true + */ +bool is_valid(const size_t N, ml::train::TensorDim::DataType d_type, + const void *X); + } /* namespace nntrainer */ #endif /* __cplusplus */ #endif /* __BLAS_INTERFACE_H__ */ 
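nntrainer::is_valid() above picks a NaN/Inf checker by tensor data type, preferring the NEON/AVX kernels and falling back to a scalar loop for FP32. A CPU-only sketch of the same dispatch shape, meant to illustrate the control flow rather than the SIMD kernels (DataType and isValidBuffer are local stand-ins):

#include <cmath>
#include <cstddef>

namespace sketch {

enum class DataType { FP16, FP32 };

template <typename T> bool allFinite(const T *x, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    if (!std::isfinite(static_cast<float>(x[i])))
      return false; // early-out on the first NaN or Inf, like the kernels
  return true;
}

bool isValidBuffer(std::size_t n, DataType dtype, const void *data) {
  switch (dtype) {
  case DataType::FP32:
    return allFinite(static_cast<const float *>(data), n);
  case DataType::FP16:
    // The real code reads _Float16/__fp16 and uses the AVX/NEON paths; a
    // float buffer stands in here so the sketch stays portable.
    return allFinite(static_cast<const float *>(data), n);
  }
  return false;
}

} // namespace sketch

In the training loop this is the check that lets backwarding_op flag an overflowed mixed-precision gradient and trigger the retry of the iteration.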
diff --git a/nntrainer/tensor/blas_neon.cpp b/nntrainer/tensor/blas_neon.cpp index 4a20031195..dc8c3474b4 100644 --- a/nntrainer/tensor/blas_neon.cpp +++ b/nntrainer/tensor/blas_neon.cpp @@ -546,6 +546,36 @@ void ele_div(const unsigned N, const float *X, const float *Y, float *Z, } } +bool isValid(const size_t N, const float *X) { + size_t i = 0; + float inf_s = std::numeric_limits::infinity(); + float32x4_t inf = vdupq_n_f32(inf_s); + uint16x8_t zero = vdupq_n_f32(0); + + for (; N - i >= 4; i += 4) { + float32x4_t vec = vld1q_f32(&X[i]); + uint32x4_t vcmp = vceqq_f32(vec, vec); + + vcmp = vceqq_f32(vcmp, zero); + + if (vaddvq_u32(vcmp)) + return false; + + vcmp = vceqq_f32(vec, inf); + + if (vaddvq_u16(vcmp)) + return false; + } + + while (i < N) { + if (X[i] != X[i] || X[i] == std::numeric_limits::infinity()) + return false; + ++i; + } + + return true; +} + #ifdef ENABLE_FP16 void hgemv(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t M, uint32_t N, @@ -2025,5 +2055,40 @@ void inv_sqrt_inplace(const unsigned int N, __fp16 *X) { } } +bool isValid(const size_t N, const __fp16 *input) { + bool temp = 0; + size_t i = 0; + __fp16 inf_s = std::numeric_limits::infinity(); + float16x8_t inf = vdupq_n_f16(inf_s); + uint16x8_t zero = vdupq_n_f16(0); + + for (; N - i >= 8; i += 8) { + float16x8_t vec = vld1q_f16(&input[i]); + + uint16x8_t vcmp = vceqq_f16(vec, vec); + + vcmp = vceqq_f16(vcmp, zero); + + if (vaddvq_u16(vcmp)) { + return false; + } + + vcmp = vceqq_f16(vec, inf); + + if (vaddvq_u16(vcmp)) { + return false; + } + } + + while (i < N) { + if (input[i] != input[i] || + input[i] == std::numeric_limits::infinity()) { + return false; + } + ++i; + } + return true; +} + #endif } // namespace nntrainer::neon diff --git a/nntrainer/tensor/blas_neon.h b/nntrainer/tensor/blas_neon.h index db1b6a5ccc..978d3428f7 100644 --- a/nntrainer/tensor/blas_neon.h +++ b/nntrainer/tensor/blas_neon.h @@ -148,6 +148,15 @@ void ele_sub(const unsigned N, const float *X, const float *Y, float *Z, void ele_div(const unsigned N, const float *X, const float *Y, float *Z, float alpha = 1.f, float beta = 0.f); +/** + * @brief check if the X has NaN value or Inf + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] input float * for Vector X + * @param[out] false if it has NaN or Inf + */ +bool isValid(const size_t N, const float *input); + #ifdef ENABLE_FP16 /** * @brief hgemv computation with neon : Y = alpha*A*X + beta*Y @@ -380,6 +389,15 @@ void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, uint32_t M, * @param X __fp16 * for Vector X */ void inv_sqrt_inplace(const unsigned int N, __fp16 *X); + +/** + * @brief check if the X is valid: Check NaN or Inf + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] X float * for Vector X + * @param[out] false if it has NaN or Inf + */ +bool isValid(const size_t N, const __fp16 *X); #endif } // namespace nntrainer::neon diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp index 9a0d235ba9..f77a012b49 100644 --- a/nntrainer/tensor/manager.cpp +++ b/nntrainer/tensor/manager.cpp @@ -407,14 +407,15 @@ std::vector Manager::requestWeights( * order with the max exec order where it will be used for clipping and then * applied to the weight. 
*/ - if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) { + if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm) || + isMixedPrecision()) { grad_exec_order.push_back(TensorPool::PERSIST_END_ORDER); // TODO: We need double check if it is OK not to add PERSIST_END_ORDER // here or add other conditions // var_exec_order.push_back(TensorPool::PERSIST_END_ORDER); } - Tensor *var = nullptr, *grad = nullptr; + Tensor *var = nullptr, *grad = nullptr, *var32 = nullptr; bool is_dependent = !shared_names.empty(); if (is_dependent) { /// shared_name is used and the orignal name is discarded @@ -431,6 +432,17 @@ std::vector Manager::requestWeights( grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix, dim_g, grad_exec_order, grad_ls, Tensor::Initializer::ZEROS); + + if (var->getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::vector var32_exec_order; + var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER); + + var32 = weight_pool.requestOrExtend(shared_name + ":var32", var32_dim, + var32_exec_order, var_ls, + Tensor::Initializer::ZEROS); + } } } else { /** case requesting fresh weights */ @@ -443,16 +455,26 @@ std::vector Manager::requestWeights( * reduce the memory. */ bool is_wgrad = true; - if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) - is_wgrad = false; + // if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) + // is_wgrad = false; grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g, grad_exec_order, grad_ls, Tensor::Initializer::ZEROS, is_wgrad); + if (var->getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::vector var32_exec_order; + var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER); + var32 = + weight_pool.request(name + ":var32", var32_dim, var32_exec_order, + var_ls, Tensor::Initializer::ZEROS); + } } } weights_v2.emplace_back(std::make_unique( - var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm)); + var, grad, var32, w_reg, w_reg_const, decay, is_dependent, + clip_by_global_norm, axis, loss_scale)); } std::transform(weights_v2.begin() + current_size, weights_v2.end(), @@ -668,15 +690,15 @@ bool Manager::isSecondLastAccess(const std::string &name, */ std::vector Manager::requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, - const TensorLifespan &lifespan, bool is_grad_clip, - Tensor::Initializer initializer) { + const std::string &suffix, const TensorLifespan &lifespan, bool is_grad_clip, + bool is_mixed_precision, Tensor::Initializer initializer) { std::vector ret; ret.reserve(dims.size()); std::vector exec; exec.reserve(1); - if (is_grad_clip) { + if (is_grad_clip || is_mixed_precision) { exec.emplace_back(TensorPool::PERSIST_END_ORDER); } else { exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second); @@ -685,7 +707,7 @@ std::vector Manager::requestWeightOptimizerVariables( /// @note this is assuming weight optimizer variables is treated as weight, if /// not, there is room to optimize below behavior for (unsigned int idx = 0; idx < dims.size(); idx++) - ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx), + ret.push_back(weight_pool.request(name + suffix + std::to_string(idx), dims[idx], exec, lifespan, initializer)); return ret; diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h index 
ab1c018153..d561770206 100644 --- a/nntrainer/tensor/manager.h +++ b/nntrainer/tensor/manager.h @@ -224,7 +224,8 @@ class Manager { */ std::vector requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, - const TensorLifespan &lifespan, bool is_grad_clip, + const std::string &suffix, const TensorLifespan &lifespan, + bool is_grad_clip, bool is_mixed_type, Tensor::Initializer initializer = Tensor::Initializer::NONE); /** @@ -494,6 +495,11 @@ class Manager { exec_mode = mode; }; + /** + * @brief return if it is mixed precsion + */ + bool isMixedPrecision() { return !istrequal(tensor_dtype[0], "FP32"); } + private: /** @todo: merge this list to one */ std::vector> weights_v2; /**< weights for the layers diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build index fe4204cf85..8f2b6e9583 100644 --- a/nntrainer/tensor/meson.build +++ b/nntrainer/tensor/meson.build @@ -44,6 +44,12 @@ cl_headers = [ arch = host_machine.cpu_family() + +if get_option('enable-avx') + tensor_sources += 'blas_avx.cpp' + tensor_headers += 'blas_avx.h' +endif + if get_option('enable-fp16') if arch == 'arm' error ('FP16/ARM code (blas_neon.cpp) uses armv8.2 instructions. armv7 is not supported.') @@ -59,9 +65,6 @@ if get_option('enable-fp16') nntrainer_inc += include_directories('matrix_transpose_neon') nntrainer_inc_abs += meson.current_source_dir() / 'matrix_transpose_neon' endif - elif get_option('enable-avx') - tensor_sources += 'blas_avx.cpp' - tensor_headers += 'blas_avx.h' endif endif diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index 5dc3c93f01..a69c6c52d4 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -3071,6 +3071,18 @@ Tensor Tensor::clone() const { return t; } +Tensor Tensor::clone(ml::train::TensorDim::DataType type) const { + if (getDataType() == type) + return clone(); + + TensorDim dim = getDim(); + dim.setDataType(type); + Tensor t(dim, true); + t.copyData(*this); + t.name = name; + return t; +} + void Tensor::reshape(const TensorDim &d) { NNTR_THROW_IF(!contiguous, std::invalid_argument) @@ -3313,13 +3325,17 @@ void Tensor::setValue(float val) { void Tensor::setZero() { if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) { if (contiguous) - sscal(size(), 0, getData(), 1); + // sscal(size(), 0, getData(), 1); + /// @note we cannot use sscal, when we set zero. if the data is inf or + /// NaN, then the inf or NaN still remain. 
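setZero() here switches from sscal(size, 0, ...) to memset precisely because scaling by zero does not clear bad values: under IEEE-754, 0.0f * NaN stays NaN and 0.0f * Inf becomes NaN rather than 0. A two-assert check of that behaviour:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();
  assert(std::isnan(0.0f * nan)); // 0 * NaN is still NaN
  assert(std::isnan(0.0f * inf)); // 0 * Inf is NaN, not 0
  return 0;
}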
+ memset(getData(), 0, sizeof(float) * size()); else apply_i([](float val) -> float { return 0; }); } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) { #ifdef ENABLE_FP16 if (contiguous) - sscal(size(), 0, getData<_FP16>(), 1); + // sscal(size(), 0, getData<_FP16>(), 1); + memset(getData<_FP16>(), 0, sizeof(_FP16) * size()); else apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; }); #else @@ -3814,6 +3830,18 @@ void Tensor::dequantize(Tensor &output, unsigned int axis) const { return; } +bool Tensor::isValid() const { + if (getDataType() == Tdatatype::FP16) { +#ifdef ENABLE_FP16 + return is_valid(dim.getDataLen(), Tdatatype::FP16, getData<_FP16>()); +#else + throw std::invalid_argument("enble-fp16 is not set"); +#endif + } else { + return is_valid(dim.getDataLen(), Tdatatype::FP32, getData()); + } +} + // namespace nntrainer } /* namespace nntrainer */ diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 211334da40..ad3781526f 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -1680,6 +1680,13 @@ class Tensor { */ Tensor clone() const; + /** + * @brief Convient wrapper for inplace copy of @a this. + * @param[in] type output tensor data type + * @retval Copied version of this + */ + Tensor clone(ml::train::TensorDim::DataType type) const; + /** * @brief Save the Tensor into file * @param[in] file output file stream @@ -2031,6 +2038,12 @@ class Tensor { static constexpr float epsilon = 1e-5; + /** + * @brief check if there is NaN or Inf element + * @param[out] bool false if there is NaN or Inf else false + */ + bool isValid() const; + private: /**< handle the data as a std::shared_ptr type */ TensorDim dim; diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp index f98c8c8356..ea8c65a7cb 100644 --- a/nntrainer/tensor/weight.cpp +++ b/nntrainer/tensor/weight.cpp @@ -34,6 +34,28 @@ Weight::Weight(const TensorDim &dim, const Tensor::Initializer init, throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); + + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + /** + * @note We assume if the Weight Data Type is not FP32, then FP32 Weight is + * necessary to maintain the accuracy. + * We could think it can be other data type and if there is the case to + * support other data type, then the code below needs to be udpated. + * + * Also, the loss_scale is not used in Weight but leave as it is for later + * usage. 
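The note above is the core of the mixed-precision weight: any non-FP32 trainable variable gets an FP32 shadow (var32) so updates accumulate in full precision and are quantized back afterwards. A minimal sketch of that master-copy update cycle, with plain floats standing in for the FP16 working copy and hypothetical names (the patch does this via Weight::applyGradient and quantizeWeight):

#include <cstddef>
#include <vector>

namespace sketch {

struct MixedWeight {
  std::vector<float> var;   // low-precision working copy read by the layers
  std::vector<float> var32; // FP32 master copy owned by training

  void applyGradient(const std::vector<float> &grad_fp32, float lr) {
    // 1. update the FP32 master in full precision
    for (std::size_t i = 0; i < var32.size(); ++i)
      var32[i] -= lr * grad_fp32[i];
    // 2. quantize the master back into the working copy
    //    (the real code converts FP32 -> FP16 via copyData)
    for (std::size_t i = 0; i < var.size(); ++i)
      var[i] = var32[i];
  }
};

} // namespace sketch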
+ */ + + if (train && dim.getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + + var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); + } else { + var32 = std::make_shared(var32_name); + } } Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, @@ -52,6 +74,93 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); + + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + if (train && dim_v.getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); + } else { + var32 = std::make_shared(var32_name); + } +} + +Weight::Weight(const Tensor &v, const Tensor &g, const Tensor &v32, + const std::string &n, bool is_dependent, + unsigned int output_axis_) : + Var_Grad(v, g, n, is_dependent), + regularizer(WeightRegularizer::NONE), + regularizer_constant(1.0f), + decay(0.0f), + clip_by_global_norm(0.0f), + output_axis(output_axis_), + loss_scale(1.0), + var32(std::make_shared(n + ":fp32")) { + + if (!g.empty() && isMixedPrecision()) { + TensorDim var32_dim(v.getDim()); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + if (!v32.empty()) + var32 = std::make_shared( + v32.getSharedDataTensor(var32_dim, 0, false, n + ":fp32")); + } +} + +Weight::Weight(Tensor *v, Tensor *g, Tensor *v32, const WeightRegularizer reg, + const float reg_const, const float decay, bool is_dependent, + const float max_norm, unsigned int output_axis_, + float loss_scale_) : + Var_Grad(v, g, is_dependent), + regularizer(reg), + regularizer_constant(reg_const), + decay(decay), + clip_by_global_norm(max_norm), + output_axis(output_axis_), + loss_scale(loss_scale_), + var32(std::shared_ptr(v32, [](void *) {})) { + if (!v32) + var32 = std::make_shared(); +} + +void Weight::applyGradient(double lr, Tensor &updated_grad) { + if (isMixedPrecision() && + updated_grad.getDataType() == ml::train::TensorDim::DataType::FP32) { + var32->add_i(updated_grad, -lr); + quantizeWeight(); + return; + } + + return applyGradient(lr); +} + +void Weight::quantizeWeight() { + if (!isMixedPrecision()) + return; + + Tensor &var = getVariableRef(); + ml::train::TensorDim::DataType type = var.getDataType(); + switch (type) { + case ml::train::TensorDim::DataType::QINT4: + // NYI + break; + case ml::train::TensorDim::DataType::QINT8: + // NYI + break; + case ml::train::TensorDim::DataType::FP16: + getVariableRef().copyData(getVariableFP32Ref()); + break; + case ml::train::TensorDim::DataType::FP32: + break; + default: + break; + } + + return; } } // namespace nntrainer diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h index 552f6d5739..ef65ca9318 100644 --- a/nntrainer/tensor/weight.h +++ b/nntrainer/tensor/weight.h @@ -46,7 +46,7 @@ class Weight : public Var_Grad { decay(0.0f), clip_by_global_norm(0.0f), output_axis(3), - loss_scale(0.0) {} + loss_scale(1.0) {} /** * @brief Construct a new Weight object @@ -66,7 +66,7 @@ class Weight : public Var_Grad { const float reg_const = 1.0f, const float decay = 0.0f, const float clip_by_global_norm = 0.0f, bool 
ng = true, bool alloc_now = false, std::string name = "", unsigned int axis = 3, - float loss_scale_ = 0.0); + float loss_scale_ = 1.0); /** * @brief Construct a new Weight object @@ -87,7 +87,7 @@ class Weight : public Var_Grad { const float reg_const = 1.0f, const float decay = 0.0f, const float clip_by_global_norm = 0.0f, bool ng = true, bool alloc_now = false, std::string name = "", unsigned int axis = 3, - float loss_scale_ = 0.0); + float loss_scale_ = 1.0); /** * @brief Construct a new Weight object @@ -114,6 +114,7 @@ class Weight : public Var_Grad { * * @param v Already created variable object * @param g Already created gradient object + * @param v32 Already created gradient object * @param n Name for this Weight * * @note This is primarily used to created wrapper of variable extracted from @@ -123,35 +124,24 @@ class Weight : public Var_Grad { * uses only, as Weight does not own the tensors v and g, and can go invalid * if the owner of these tensors free the tensors. */ - explicit Weight(const Tensor &v, const Tensor &g, const std::string &n = "", - bool is_dependent = false, unsigned int output_axis_ = 3) : - Var_Grad(v, g, n, is_dependent), - regularizer(WeightRegularizer::NONE), - regularizer_constant(1.0f), - decay(0.0f), - clip_by_global_norm(0.0f), - output_axis(output_axis_), - loss_scale(0.0) {} + explicit Weight(const Tensor &v, const Tensor &g, const Tensor &v32, + const std::string &n = "", bool is_dependent = false, + unsigned int output_axis_ = 3); /** * @brief Construct a new Weight object * * @param v ptr to already created variable tensor * @param g ptr to already created gradient tensor + * @param v32 ptr to already created variable32 tensor * @param reg Regularizer for the weight * @param reg_const Constant multiplier for regularizer */ - explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg, - const float reg_const, const float decay, - bool is_dependent = false, const float max_norm = 0.0f, - unsigned int output_axis_ = 3, float loss_scale_ = 0.0f) : - Var_Grad(v, g, is_dependent), - regularizer(reg), - regularizer_constant(reg_const), - decay(decay), - clip_by_global_norm(max_norm), - output_axis(output_axis_), - loss_scale(loss_scale_) {} + explicit Weight(Tensor *v, Tensor *g, Tensor *v32, + const WeightRegularizer reg, const float reg_const, + const float decay, bool is_dependent = false, + const float max_norm = 0.0f, unsigned int output_axis_ = 3, + float loss_scale_ = 1.0f); /** * @brief Swap for weight @@ -170,6 +160,7 @@ class Weight : public Var_Grad { swap(lhs.output_axis, rhs.output_axis); swap(lhs.opt_vars, rhs.opt_vars); swap(lhs.loss_scale, rhs.loss_scale); + swap(lhs.var32, rhs.var32); } /** @@ -213,6 +204,8 @@ class Weight : public Var_Grad { w.var = std::make_shared(this->var->clone()); if (!this->grad->empty()) w.grad = std::make_shared(this->grad->clone()); + if (!this->var32->empty()) + w.var32 = std::make_shared(this->var32->clone()); return w; } @@ -294,6 +287,13 @@ class Weight : public Var_Grad { */ void applyGradient(double lr) { var->add_i(*grad.get(), -lr); } + /** + * @brief Apply the gradient to the weight with updated gradient + * @param[in] updated_grad gradient tensor which is updated in optimizer + * it might be different data type with gradient in weight. 
e.g., FP32 + */ + void applyGradient(double lr, Tensor &updated_grad); + /** * @brief Check if the gradient is supposed to be clipped by global norm with * the given max_norm value @@ -316,6 +316,16 @@ class Weight : public Var_Grad { return clip_by_global_norm > epsilon; } + /** + * @brief Check if the variable type is not full precision + * + * @return true if it is not full precision + * @return false otherwise + */ + bool isMixedPrecision() const { + return ((var->getDataType() != ml::train::TensorDim::DataType::FP32)); + } + /** * @brief clip the gradient value based on the given global norm * @@ -326,6 +336,32 @@ class Weight : public Var_Grad { grad->multiply_i(clip_by_global_norm / (global_norm + epsilon)); } + /** + * @brief Get the variable FP32 tensor (by reference) + * + * @return Tensor Variable FP32 tensor + */ + Tensor &getVariableFP32Ref() { return *var32.get(); } + + /** + * @brief Quantize var32 to var + * + */ + void quantizeWeight(); + + /** + * @brief set loss scale + * @param[in] scale + * + */ + void setLossScale(float scale) { loss_scale = scale; }; + + /** + * @brief get loss scale + * + */ + const float getLossScale() { return loss_scale; }; + private: static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */ static constexpr float epsilon_decay = @@ -337,7 +373,8 @@ class Weight : public Var_Grad { float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */ unsigned int output_axis; float loss_scale; - std::vector opt_vars; /**< optimizer variables */ + std::vector + opt_vars; /**< optimizer variables : We assume it is always full-precision */ std::shared_ptr var32; /** diff --git a/nntrainer/utils/base_properties.h b/nntrainer/utils/base_properties.h index 259637a6d9..b7e3b38942 100644 --- a/nntrainer/utils/base_properties.h +++ b/nntrainer/utils/base_properties.h @@ -705,6 +705,23 @@ class TensorFormat final : public EnumProperty { set(value); }; }; + +// /** +// * @brief trainable property, use this to set and check how if certain layer is +// * trainable +// * +// */ +// class Trainable : public nntrainer::Property { +// public: +// /** +// * @brief Construct a new Trainable object +// * +// */ +// Trainable(bool val = true) : nntrainer::Property(val) {} +// static constexpr const char *key = "trainable"; +// using prop_tag = bool_prop_tag; +// }; + } // namespace props } // namespace nntrainer diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec index 36ba371d22..2f1dc57f68 100644 --- a/packaging/nntrainer.spec +++ b/packaging/nntrainer.spec @@ -65,6 +65,13 @@ %define neon_support -Denable-neon=false %endif # arch aarch64 +%ifarch x86_64 +%define enable_avx 1 +%define avx_support -Denable-avx=true +%else +%define avx_support -Denable-avx=false +%endif # arch x86_64 + Name: nntrainer Summary: Software framework for training neural networks @@ -410,7 +417,7 @@ meson --buildtype=plain --prefix=%{_prefix} --sysconfdir=%{_sysconfdir} \ %{enable_reduce_tolerance} %{configure_subplugin_install_path} %{enable_debug} \ -Dml-api-support=enabled -Denable-nnstreamer-tensor-filter=enabled \ -Denable-nnstreamer-tensor-trainer=enabled -Denable-capi=enabled \ - %{fp16_support} %{neon_support} build + %{fp16_support} %{neon_support} %{avx_support} build ninja -C build %{?_smp_mflags} @@ -563,6 +570,10 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/ %{_includedir}/nntrainer/util_simd_neon.h %endif +%if 0%{?enable_avx} +%{_includedir}/nntrainer/blas_avx.h +%endif + %files devel-static %{_libdir}/libnntrainer*.a %exclude 
%{_libdir}/libcapi*.a diff --git a/packaging/unittest_layers.tar.gz b/packaging/unittest_layers.tar.gz index 7a435aadf4..3bd488a0a2 100644 Binary files a/packaging/unittest_layers.tar.gz and b/packaging/unittest_layers.tar.gz differ diff --git a/packaging/unittest_models_v3.tar.gz b/packaging/unittest_models_v3.tar.gz index abc7ead4a4..49a1f1b2ad 100644 Binary files a/packaging/unittest_models_v3.tar.gz and b/packaging/unittest_models_v3.tar.gz differ diff --git a/test/include/nntrainer_test_util.h b/test/include/nntrainer_test_util.h index 74eef4abaa..8e16b6a9f4 100644 --- a/test/include/nntrainer_test_util.h +++ b/test/include/nntrainer_test_util.h @@ -347,6 +347,29 @@ float mse(Ta *A, Tb *B, uint32_t size) { return mse; } +/** + * @brief calculate mean squared error + * + * @param A const prediction data + * @param B const reference data + * @param size data size + * @return mean squared error value + */ +template +float mse(const Ta *A, const Tb *B, uint32_t size) { + float pred; + float ref; + float mse_error = 0; + for (uint32_t i = 0; i < size; i++) { + pred = A[i]; + ref = B[i]; + float diff = pred - ref; + mse_error += pow(diff, 2); + } + float mse = mse_error / size; + return mse; +} + /** * @brief A helper struct for performing static_cast operations on types. * diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py index a56f437785..422c737487 100644 --- a/test/input_gen/genModelTests_v2.py +++ b/test/input_gen/genModelTests_v2.py @@ -11,6 +11,7 @@ import math from recorder_v2 import record_v2, inspect_file, _rand_like import torch +from torch import autocast class ReduceMeanLast(torch.nn.Module): def __init__(self): @@ -307,6 +308,40 @@ def forward(self, inputs, labels): loss = self.loss(out, labels[0]) return out, loss +class LinearMixedPrecision(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(3, 10) + self.loss = torch.nn.MSELoss() + + def forward(self, inputs, labels): + with autocast(device_type='cuda', dtype=torch.float16): + input=inputs[0].to('cuda') + label=labels[0].to('cuda') + out = self.fc(input) + return out + + def getOptimizer(self): + return torch.optim.Adam(self.parameters(), lr=0.1) + +class LinearMixedPrecisionNaNSGD(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc0 = torch.nn.Linear(1, 1) + self.fc1 = torch.nn.Linear(1, 1) + self.loss = torch.nn.MSELoss() + + def forward(self, inputs, labels): + with autocast(device_type='cuda', dtype=torch.float16): + input=inputs[0].to('cuda') + label=labels[0].to('cuda') + out = self.fc0(input) + out = self.fc1(out) + return out + + def getOptimizer(self): + return torch.optim.SGD(self.parameters(), lr=0.1) + if __name__ == "__main__": record_v2( ReduceMeanLast(), @@ -537,5 +572,28 @@ def forward(self, inputs, labels): name="non_trainable_fc_idx3" ) - # Function to check the created golden test file + fc_mixed_training = LinearMixedPrecision() + record_v2( + fc_mixed_training, + iteration=3, + input_dims=[(1,3)], + input_dtype=[float], + label_dims=[(1,10)], + name="fc_mixed_training", + optimizer=fc_mixed_training.getOptimizer() + ) + + fc_mixed_training_nan_sgd = LinearMixedPrecisionNaNSGD() + record_v2( + fc_mixed_training_nan_sgd, + iteration=5, + input_dims=[(1,1)], + input_dtype=[float], + label_dims=[(1,1)], + name="fc_mixed_training_nan_sgd", + optimizer=fc_mixed_training_nan_sgd.getOptimizer() + ) + +# Function to check the created golden test file inspect_file("non_trainable_fc_idx3.nnmodelgolden") + diff --git 
a/test/input_gen/recorder_v2.py b/test/input_gen/recorder_v2.py index 9bc219c767..6b8f42ff88 100644 --- a/test/input_gen/recorder_v2.py +++ b/test/input_gen/recorder_v2.py @@ -12,6 +12,8 @@ import random import torch # torch used here is torch==1.9.1 import numpy as np +import torch.cuda.amp as amp +from torch import autocast from transLayer_v2 import params_translated @@ -29,13 +31,31 @@ def _get_writer(file): - def write_fn(items): + def write_fn(items, type = 'float32'): if not isinstance(items, (list, tuple)): items = [items] for item in items: - np.array([item.numel()], dtype="int32").tofile(file) - item.detach().cpu().numpy().tofile(file) + print(item.numel(), " -0-----") + print(item) + np.array([item.numel()], dtype='int32').tofile(file) + a=np.array(item.detach().cpu(), dtype=type) + a.tofile(file) + print(a.dtype) + + return items + + return write_fn + +def _get_writer_mixed(file): + def write_fn(items, num_type = 'int32', type = 'float32'): + if not isinstance(items, (list, tuple)): + items = [items] + + for item in items: + np.array([item.numel()], dtype=num_type).tofile(file) + a=np.array(item.detach().cpu(), dtype=type) + a.tofile(file) return items @@ -96,14 +116,65 @@ def record_iteration(write_fn): norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.0001) optimizer.step() + def record_iteration_with_amp(write_fn, inputs, labels, is_nan, scaler): + model_= model.cuda() + + print(inputs[0], " inputs inside") + output = model_(inputs[0], labels[0]) + + print("model output type: ",output.dtype) + + with autocast(device_type='cuda', dtype=torch.float16): + l=model_.loss(output, labels[0].to('cuda')) + + optimizer.zero_grad() + + scaler.scale(l).backward() + print("Gradient ---------------") + for param in model_.parameters(): + print (param.grad) + mask = torch.logical_or(torch.isnan(param.grad), torch.isinf(param.grad)) + check_nan = mask.int() + if check_nan.sum().item(): + is_nan = True + break + else: + is_nan = False + + + if not is_nan: + print("------------------------------- not nan") + write_fn(output,'int32','float32') + return output, is_nan + with open(file_name, "wb") as f: # write number of iterations + print("iteration : ", iteration) np.array([iteration], dtype="int32").tofile(f) - write_fn = _get_writer(f) - for _ in range(iteration): - record_iteration(write_fn) - + write_fn = _get_writer_mixed(f) + for i in range(iteration): + if input_label_reader != None: + inputs, labels = input_label_reader(input_dims, label_dims, input_dtype) + else: + inputs = _rand_like(input_dims, dtype=input_dtype if input_dtype is not None else float) + labels = _rand_like(label_dims, dtype=float) + print("inputs ==============") + write_fn(inputs,'int32', 'float32') + print("labels ==============") + write_fn(labels, 'int32', 'float32') + is_nan = True + print("=========================== ", i) + scaler = amp.GradScaler() + print("weights ==============") + write_fn(list(t for _, t in params_translated(model)),'int16','float16') + print("\n\n") + while(is_nan): + print( "before is_nan_", is_nan) + output,is_nan_ = record_iteration_with_amp(write_fn, inputs, labels, is_nan, scaler) + is_nan = is_nan_ + print( "after is_nan_", is_nan) + scaler.step(optimizer) + scaler.update() ## # @brief inpsect if file is created correctly diff --git a/test/nntrainer_test_util.cpp b/test/nntrainer_test_util.cpp index bcc33e40c8..5777bb75b2 100644 --- a/test/nntrainer_test_util.cpp +++ b/test/nntrainer_test_util.cpp @@ -332,6 +332,7 @@ void sizeCheckedReadTensor(nntrainer::Tensor &t, std::ifstream &file, 
nntrainer::checkedRead(file, (char *)&sz, sizeof(unsigned)); } else if (t.getDataType() == ml::train::TensorDim::DataType::FP16) { #ifdef ENABLE_FP16 + // This needs to be fixed. sz is always unsinged int type. nntrainer::checkedRead(file, (char *)&sz, sizeof(_FP16)); #else throw std::invalid_argument("Error: enable-fp16 is not enabled"); diff --git a/test/unittest/layers/layers_golden_tests.cpp b/test/unittest/layers/layers_golden_tests.cpp index 56d591019b..4312dc57b7 100644 --- a/test/unittest/layers/layers_golden_tests.cpp +++ b/test/unittest/layers/layers_golden_tests.cpp @@ -48,7 +48,8 @@ static const std::string getGoldenPath(const std::string &file_name) { static InitLayerContext createInitContext(Layer *layer, const std::string &input_shape_str, - std::array tensor_type) { + std::array tensor_type, + ml::train::ExecutionMode mode) { struct shape_parser_ : Property { using prop_tag = dimension_prop_tag; }; @@ -68,7 +69,7 @@ createInitContext(Layer *layer, const std::string &input_shape_str, } InitLayerContext context({parsed.begin(), parsed.end()}, {true}, false, - "golden_test", "", 0.0, tensor_type); + "golden_test", "", 0.0, tensor_type, 1.0, mode); layer->finalize(context); for (auto &dim : context.getMutableInputDimensions()) { @@ -84,7 +85,9 @@ createInitContext(Layer *layer, const std::string &input_shape_str, } static TensorPacks prepareTensors(const InitLayerContext &context, - std::ifstream &file) { + std::ifstream &file, + std::array tensor_type, + std::string l_type = "") { auto allocate_inouts = [&file](const auto &dims) { std::vector vg; vg.reserve(dims.size()); @@ -119,14 +122,29 @@ static TensorPacks prepareTensors(const InitLayerContext &context, return vg; }; - auto allocate_weights = [&file](const auto &specs) { + auto allocate_weights = [&file, tensor_type, l_type](const auto &specs) { std::vector weights; + std::vector weights_; weights.reserve(specs.size()); - + weights_.reserve(specs.size()); for (auto &spec : specs) { - weights.emplace_back(spec, true); - sizeCheckedReadTensor(weights.back().getVariableRef(), file, - weights.back().getName()); + if (istrequal(l_type, "batch_normalization")) { + WeightSpec spec_ = spec; + std::get<0>(spec_).setDataType( + str_converter:: + from_string(tensor_type[1])); + weights_.emplace_back(spec_, true); + sizeCheckedReadTensor(weights_.back().getVariableRef(), file, + weights_.back().getName()); + + weights.emplace_back(spec, true); + weights.back().getVariableRef().copyData( + weights_.back().getVariableRef()); + } else { + weights.emplace_back(spec, true); + sizeCheckedReadTensor(weights.back().getVariableRef(), file, + weights.back().getName()); + } weights.back().getGradientRef().setZero(); } return weights; @@ -156,7 +174,7 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) { }; auto rc = - RunLayerContext("golden", true, 0.0f, false, create_view(weights), + RunLayerContext("golden", true, 0.0f, false, 1.0, create_view(weights), create_view(ins), create_view(outs), create_view(tensors)); auto num_outputs = rc.getNumOutputs(); @@ -172,7 +190,9 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) { static void compareRunContext(RunLayerContext &rc, std::ifstream &file, bool skip_grad, bool skip_deriv, - bool dropout_match, bool skip_cos_sim) { + bool dropout_match, bool skip_cos_sim, + std::array tensor_type, + std::string layer_type) { file.seekg(0, std::ios::beg); auto compare_percentage_tensors = [](const Tensor &t1, const Tensor &t2, @@ -280,21 +300,22 @@ static void 
compareRunContext(RunLayerContext &rc, std::ifstream &file, auto compare_tensors = [&file, compare_percentage_tensors]( unsigned length, auto tensor_getter, auto pred, bool skip_compare, bool skip_cos_sim, - const std::string &name, + const std::string &name, TensorDim::DataType d_type, unsigned int match_percentage = 100) { for (unsigned i = 0; i < length; ++i) { if (!pred(i)) { continue; } const auto &tensor = tensor_getter(i); - auto answer = tensor.clone(); + auto answer = tensor.clone(d_type); sizeCheckedReadTensor(answer, file, name + " at " + std::to_string(i)); if (skip_compare) { continue; } - EXPECT_TRUE(compare_percentage_tensors(tensor, answer, match_percentage, - skip_cos_sim)) + EXPECT_TRUE(compare_percentage_tensors( + tensor.getDataType() != d_type ? tensor.clone(d_type) : tensor, answer, + match_percentage, skip_cos_sim)) << name << " at " << std::to_string(i); } }; @@ -313,29 +334,47 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file, compare_tensors( rc.getNumWeights(), [&rc](unsigned idx) -> const auto & { return rc.getWeight(idx); }, - always_read, skip_compare, skip_cos_sim, "initial_weights"); + always_read, skip_compare, skip_cos_sim, "initial_weights", + str_converter::from_string(tensor_type[1])); + + TensorDim::DataType d_type = + str_converter::from_string(tensor_type[2]); + if (layer_type == "embedding") { + d_type = TensorDim::DataType::FP32; + } + compare_tensors( rc.getNumInputs(), [&rc](unsigned idx) -> const auto & { return rc.getInput(idx); }, - always_read, !skip_compare, skip_cos_sim, "inputs"); + always_read, !skip_compare, skip_cos_sim, "inputs", d_type); compare_tensors( rc.getNumOutputs(), [&rc](unsigned idx) -> const auto & { return rc.getOutput(idx); }, - always_read, !skip_compare, skip_cos_sim, "outputs", match_percentage); + always_read, !skip_compare, skip_cos_sim, "outputs", + str_converter::from_string(tensor_type[2]), + match_percentage); compare_tensors( rc.getNumWeights(), [&rc](unsigned idx) -> const auto & { return rc.getWeightGrad(idx); }, - only_read_trainable, skip_grad, skip_cos_sim, "gradients"); + only_read_trainable, skip_grad, skip_cos_sim, "gradients", + str_converter::from_string(tensor_type[2])); compare_tensors( rc.getNumWeights(), [&rc](unsigned idx) -> const auto & { return rc.getWeight(idx); }, - always_read, !skip_compare, skip_cos_sim, "weights"); + always_read, !skip_compare, skip_cos_sim, "weights", + str_converter::from_string(tensor_type[2])); compare_tensors( rc.getNumInputs(), [&rc](unsigned idx) -> const auto & { return rc.getOutgoingDerivative(idx); }, - always_read, skip_deriv, skip_cos_sim, "derivatives", match_percentage); + always_read, skip_deriv, skip_cos_sim, "derivatives", d_type, + match_percentage); } LayerGoldenTest::~LayerGoldenTest() {} @@ -385,9 +424,14 @@ TEST_P(LayerGoldenTest, run) { getGoldenPath(std::get<3>(GetParam())), std::ios::in | std::ios::binary); auto &input_dims = std::get<2>(GetParam()); + ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN; + if (shouldForwardWithInferenceMode()) + mode = ml::train::ExecutionMode::INFERENCE; + auto ic = - createInitContext(layer.get(), input_dims, {format, type_w, type_a}); - auto tensors = prepareTensors(ic, golden_file); + createInitContext(layer.get(), input_dims, {format, type_w, type_a}, mode); + auto tensors = + prepareTensors(ic, golden_file, {format, type_w, type_a}, layer->getType()); auto rc = prepareRunContext(tensors); bool skip_calc_grad = shouldSkipCalcGrad(); @@ -425,7 +469,8 @@ 
TEST_P(LayerGoldenTest, run) { } compareRunContext(rc, golden_file, skip_calc_grad, skip_calc_deriv, - dropout_compare_60_percent, skip_cos_sim); + dropout_compare_60_percent, skip_cos_sim, + {format, type_w, type_a}, layer->getType()); EXPECT_TRUE(true); // stub test for tcm } diff --git a/test/unittest/layers/unittest_layer_node.cpp b/test/unittest/layers/unittest_layer_node.cpp index 3b41f02f30..37287f7ce5 100644 --- a/test/unittest/layers/unittest_layer_node.cpp +++ b/test/unittest/layers/unittest_layer_node.cpp @@ -131,7 +131,7 @@ TEST(nntrainer_LayerNode, finalize_05_n) { nntrainer::createLayerNode(nntrainer::IdentityLayer::type)); EXPECT_NO_THROW(lnode->setProperty({"input_shape=1:1:1", "name=abc"})); EXPECT_NO_THROW(lnode->finalize()); - EXPECT_NO_THROW(lnode->configureRunContext({}, {&input}, {}, {})); + EXPECT_NO_THROW(lnode->configureRunContext({}, {&input}, {}, {}, 1.0)); EXPECT_THROW(lnode->finalize(), std::runtime_error); } @@ -298,7 +298,7 @@ TEST(nntrainer_LayerNode, setWeights_02_n) { EXPECT_NO_THROW(lnode = nntrainer::createLayerNode(nntrainer::IdentityLayer::type)); EXPECT_NO_THROW(lnode->setProperty({"input_shape=1:1:1", "name=abc"})); - EXPECT_NO_THROW(lnode->configureRunContext({&weight}, {&input}, {}, {})); + EXPECT_NO_THROW(lnode->configureRunContext({&weight}, {&input}, {}, {}, 1.0)); EXPECT_THROW(lnode->setWeights(new_weights), std::runtime_error); } diff --git a/test/unittest/layers/unittest_layers_batch_normalization.cpp b/test/unittest/layers/unittest_layers_batch_normalization.cpp index e4e1b06f55..34272290a5 100644 --- a/test/unittest/layers/unittest_layers_batch_normalization.cpp +++ b/test/unittest/layers/unittest_layers_batch_normalization.cpp @@ -68,8 +68,7 @@ auto bn_basic_channels_inference_w16a16 = LayerGoldenTestParamType( auto bn_basic_width_training_w16a16 = LayerGoldenTestParamType( nntrainer::createLayer, {}, "2:1:1:10", - "bn_width_training_w16a16.nnlayergolden", bn_option, "nchw", "fp16", - "fp16"); + "bn_width_training_w16a16.nnlayergolden", bn_option, "nchw", "fp16", "fp16"); auto bn_basic_width_inference_w16a16 = LayerGoldenTestParamType( nntrainer::createLayer, {}, "2:1:1:10", @@ -77,8 +76,8 @@ auto bn_basic_width_inference_w16a16 = LayerGoldenTestParamType( "fp16", "fp16"); GTEST_PARAMETER_TEST(BatchNormalization16, LayerGoldenTest, - ::testing::Values(bn_basic_channels_training_w16a16, - bn_basic_channels_inference_w16a16, + ::testing::Values(bn_basic_channels_inference_w16a16, + bn_basic_channels_training_w16a16, bn_basic_width_training_w16a16, bn_basic_width_inference_w16a16)); #endif diff --git a/test/unittest/layers/unittest_layers_convolution2d.cpp b/test/unittest/layers/unittest_layers_convolution2d.cpp index 724c79079b..92d9c593e7 100644 --- a/test/unittest/layers/unittest_layers_convolution2d.cpp +++ b/test/unittest/layers/unittest_layers_convolution2d.cpp @@ -198,3 +198,185 @@ GTEST_PARAMETER_TEST( conv2d_mb_valid_drop_last, conv2d_sb_no_overlap, conv2d_mb_no_overlap, conv2d_sb_1x1_kernel, conv2d_mb_1x1_kernel, conv2d_sb_dilation, conv2d_mb_dilation, conv2d_sb_same_dilation, conv2d_mb_same_dilation)); + +#ifdef ENABLE_FP16 +auto conv2d_sb_minimum_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=2,2"}, "1:1:4:4", + "conv2d_sb_minimum_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_minimum_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=2,2"}, "3:1:4:4", + 
"conv2d_mb_minimum_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_remain_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=2", "kernel_size=3,3", "padding=same"}, "1:1:4:4", + "conv2d_sb_same_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_remain_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=2", "kernel_size=3,3", "padding=same"}, "3:1:4:4", + "conv2d_mb_same_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_uneven_remain_1_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=same", + }, + "1:3:4:4", "conv2d_sb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_uneven_remain_2_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=0,1,0,1", + }, + "1:3:4:4", "conv2d_sb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_uneven_remain_1_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=same", + }, + "3:3:4:4", "conv2d_mb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_uneven_remain_2_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=0,1,0,1", + }, + "3:3:4:4", "conv2d_mb_same_uneven_remain_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_valid_drop_last_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=valid", + }, + "1:3:7:7", "conv2d_sb_valid_drop_last_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_valid_drop_last_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "stride=2,2", + "padding=valid", + }, + "3:3:7:7", "conv2d_mb_valid_drop_last_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_no_overlap_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=2,2", "stride=3,3"}, "1:2:5:5", + "conv2d_sb_no_overlap_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_no_overlap_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=3", + "kernel_size=2,2", + "stride=3,3", + }, + "3:2:5:5", "conv2d_mb_no_overlap_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_1x1_kernel_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + {"filters=3", "kernel_size=1,1", "stride=2,2"}, "1:2:5:5", + "conv2d_sb_1x1_kernel_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_1x1_kernel_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=3", + "kernel_size=1,1", + "stride=2,2", + }, + "3:2:5:5", "conv2d_mb_1x1_kernel_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", 
"fp16", "fp16"); + +auto conv2d_sb_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "dilation=2,2", + }, + "1:3:11:11", "conv2d_sb_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "dilation=2,2", + }, + "3:3:11:11", "conv2d_mb_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_sb_same_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "padding=same", + "dilation=2,2", + }, + "1:3:11:11", "conv2d_sb_same_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +auto conv2d_mb_same_dilation_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, + { + "filters=2", + "kernel_size=3,3", + "padding=same", + "dilation=2,2", + }, + "3:3:11:11", "conv2d_mb_same_dilation_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST( + Convolution2D16, LayerGoldenTest, + ::testing::Values(conv2d_sb_minimum_w16a16, conv2d_mb_minimum_w16a16, + conv2d_sb_same_remain_w16a16, conv2d_mb_same_remain_w16a16, + conv2d_sb_same_uneven_remain_1_w16a16, + conv2d_sb_same_uneven_remain_2_w16a16, + conv2d_mb_same_uneven_remain_1_w16a16, + conv2d_mb_same_uneven_remain_2_w16a16, + conv2d_sb_valid_drop_last_w16a16, + conv2d_mb_valid_drop_last_w16a16, + conv2d_sb_no_overlap_w16a16, conv2d_mb_no_overlap_w16a16, + conv2d_sb_1x1_kernel_w16a16, conv2d_mb_1x1_kernel_w16a16, + conv2d_sb_dilation_w16a16, conv2d_mb_dilation_w16a16, + conv2d_sb_same_dilation_w16a16, + conv2d_mb_same_dilation_w16a16)); +#endif diff --git a/test/unittest/models/meson.build b/test/unittest/models/meson.build index 7166fc41ff..3f17369f94 100644 --- a/test/unittest/models/meson.build +++ b/test/unittest/models/meson.build @@ -1,4 +1,5 @@ test_name = 'unittest_models' +mixed_test_name = 'unittest_mixed_models' test_target = [] @@ -11,6 +12,30 @@ models_targets = [ # disable temperally ] +mixed_test_targets = [ + 'models_test_utils.cpp', + 'models_golden_test.cpp', + 'unittest_models_mixed_precision.cpp', +] + +if get_option('enable-fp16') + mixed_exe = executable( + mixed_test_name, + mixed_test_targets, + include_directories: include_directories('.'), + dependencies: [ + nntrainer_test_main_deps, nntrainer_ccapi_dep + ], + install: get_option('enable-test'), + install_dir: application_install_dir + ) + + test(mixed_test_name, mixed_exe, + args: '--gtest_output=xml:@0@/@1@.xml'.format(meson.build_root(), mixed_test_name), + timeout: test_timeout + ) +endif + test_target += models_targets exe = executable( test_name, diff --git a/test/unittest/models/models_test_utils.cpp b/test/unittest/models/models_test_utils.cpp index 741e008994..ac956d479b 100644 --- a/test/unittest/models/models_test_utils.cpp +++ b/test/unittest/models/models_test_utils.cpp @@ -50,8 +50,41 @@ static sharedConstTensors toSharedTensors(const std::vector &ts) { static void verify(const nntrainer::Tensor &actual, const nntrainer::Tensor &expected, const std::string &error_msg) { + bool equal = false; + + if (actual.getDataType() == ml::train::TensorDim::DataType::FP32 && + expected.getDataType() == ml::train::TensorDim::DataType::FP32) { + equal = (actual == expected); + if (!equal) { + float mseError = mse(actual.getData(), + 
expected.getData(), actual.size()); + if (mseError > 1e-4) { + equal = false; + } else { + equal = true; + } + } + } + +#ifdef ENABLE_FP16 + if (!equal) { + if (actual.getDataType() == ml::train::TensorDim::DataType::FP16 && + expected.getDataType() == ml::train::TensorDim::DataType::FP16) { + float mseError = mse<_FP16>(actual.getData<_FP16>(), + expected.getData<_FP16>(), actual.size()); + if (mseError > 1e-2) { + equal = false; + } else { + equal = true; + } + } + } +#endif + + if (!equal) { + nntrainer::Tensor diff = actual.subtract(expected); + const float *diff_data = diff.getData(); - if (actual != expected) { std::cout << "============================================================\n"; std::cout << "\033[1;33m" << error_msg << "\033[0m\n"; @@ -60,8 +93,6 @@ static void verify(const nntrainer::Tensor &actual, << " - " << expected; if (actual.getDim() == expected.getDim()) { - nntrainer::Tensor diff = actual.subtract(expected); - const float *diff_data = diff.getData(); std::cout << "\033[1;33mdifference\033[0m " << diff; std::cout << "number of data: " << diff.size() << std::endl; std::cout << "\033[4;33mMAX DIFF: " @@ -119,6 +150,12 @@ class IterationForGolden { } Tensor &t = rc.getWeight(i); + + if (t.getDataType() != ml::train::TensorDim::DataType::FP32) { + Tensor &t32 = rc.getWeightFP32(i); + weights32.push_back(t32); + } + weights.push_back(t); expected_weights.push_back(t.clone()); } @@ -158,6 +195,10 @@ class IterationForGolden { } else { for (unsigned int i = 0; i < weights.size(); ++i) { weights.at(i).fill(expected_weights.at(i)); + if (iteration == 0 && + weights.at(i).getDataType() != ml::train::TensorDim::DataType::FP32) + weights32.at(i).fill( + weights.at(i).clone(ml::train::TensorDim::DataType::FP32)); + } } @@ -174,6 +215,7 @@ class IterationForGolden { std::vector inputs; std::vector labels; std::vector weights; + std::vector weights32; std::vector expected_weights; std::vector expected_outputs; }; diff --git a/test/unittest/models/unittest_models_mixed_precision.cpp b/test/unittest/models/unittest_models_mixed_precision.cpp new file mode 100644 index 0000000000..04c1495491 --- /dev/null +++ b/test/unittest/models/unittest_models_mixed_precision.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Jijoong Moon + * + * @file unittest_models_mixed_precision.cpp + * @date 3 May 2024 + * @brief unittest models to cover mixed precision + * @see https://github.com/nnstreamer/nntrainer + * @author Jijoong Moon + * @bug No known bugs except for NYI items + */ + +#include + +#include + +#include +#include +#include + +#include + +using namespace nntrainer; + +static std::unique_ptr fc_mixed_training() { + std::unique_ptr nn(new NeuralNetwork()); + nn->setProperty( + {"batch_size=1", "model_tensor_type=FP16-FP16", "loss_scale=65536"}); + + auto graph = makeGraph({ + {"input", {"name=in", "input_shape=1:1:3"}}, + {"Fully_connected", {"name=fc", "input_layers=in", "unit=10"}}, + {"mse", {"name=loss", "input_layers=fc"}}, + }); + for (auto &node : graph) { + nn->addLayer(node); + } + + nn->setOptimizer(ml::train::createOptimizer( + "adam", {"learning_rate = 0.1", "torch_ref=true"})); + + return nn; +} + +static std::unique_ptr fc_mixed_training_nan_sgd() { + std::unique_ptr nn(new NeuralNetwork()); + nn->setProperty( + {"batch_size=1", "model_tensor_type=FP16-FP16", "loss_scale=65536"}); + + auto graph = makeGraph({ + {"input", {"name=in", "input_shape=1:1:1"}}, + {"Fully_connected", {"name=fc0", "input_layers=in", "unit=1"}}, + 
{"Fully_connected", {"name=fc1", "input_layers=fc0", "unit=1"}}, + {"mse", {"name=loss", "input_layers=fc1"}}, + }); + for (auto &node : graph) { + nn->addLayer(node); + } + + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); + + return nn; +} + +GTEST_PARAMETER_TEST( + MixedPrecision, nntrainerModelTest, + ::testing::ValuesIn({ + mkModelTc_V2(fc_mixed_training, "fc_mixed_training", + ModelTestOption::ALL_V2), + mkModelTc_V2(fc_mixed_training_nan_sgd, "fc_mixed_training_nan_sgd", + ModelTestOption::ALL_V2), + }), + [](const testing::TestParamInfo &info) + -> const auto & { return std::get<1>(info.param); }); diff --git a/test/unittest/unittest_nntrainer_tensor.cpp b/test/unittest/unittest_nntrainer_tensor.cpp index 94aa01836d..d5b6a028f9 100644 --- a/test/unittest/unittest_nntrainer_tensor.cpp +++ b/test/unittest/unittest_nntrainer_tensor.cpp @@ -4704,6 +4704,30 @@ TEST(nntrainer_Tensor, inv_sqrt_i_uncontiguous_p) { } } +/** + * @brief fp16 tensor has NaN + */ +TEST(nntrainer_Tensor, is_valid_01) { + size_t batch = 1; + size_t channel = 3; + size_t height = 4; + size_t width = 5; + + nntrainer::Tensor input( + {batch, + channel, + height, + width, + {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}}, + true, nntrainer::Tensor::Initializer::ZEROS); + + EXPECT_EQ(input.isValid(), true); + + input.setValue(0, 0, 0, 0, std::nan("1")); + + EXPECT_EQ(input.isValid(), false); +} + int main(int argc, char **argv) { int result = -1; diff --git a/test/unittest/unittest_nntrainer_tensor_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_fp16.cpp index 2b0d9c040d..58455757c5 100644 --- a/test/unittest/unittest_nntrainer_tensor_fp16.cpp +++ b/test/unittest/unittest_nntrainer_tensor_fp16.cpp @@ -6196,6 +6196,34 @@ TEST(nntrainer_Tensor, dequantize_06_p) { EXPECT_EQ(output, answer3); } +/** + * @brief fp16 tensor has NaN + */ +TEST(nntrainer_Tensor, is_valid_01) { + size_t batch = 1; + size_t channel = 3; + size_t height = 4; + size_t width = 5; + + nntrainer::Tensor input( + {batch, + channel, + height, + width, + {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}}, + true, nntrainer::Tensor::Initializer::ZEROS); + + EXPECT_EQ(input.isValid(), true); + + input.setValue(0, 0, 0, 0, std::nan("1")); + + EXPECT_EQ(input.isValid(), false); + + input.setValue(0, 0, 0, 0, std::numeric_limits::infinity()); + + EXPECT_EQ(input.isValid(), false); +} + GTEST_API_ int main(int argc, char **argv) { int result = -1; diff --git a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp index 2c81bdcbd4..ffde1f9273 100644 --- a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp +++ b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp @@ -1303,6 +1303,38 @@ TEST(nntrainer_Tensor, inv_sqrt_i_p) { EXPECT_EQ(flag, true); } +/** + * @brief fp16 tensor has NaN + */ +TEST(nntrainer_Tensor, is_valid_01) { + size_t batch = 1; + size_t channel = 3; + size_t height = 4; + size_t width = 5; + + nntrainer::Tensor input( + {batch, + channel, + height, + width, + {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}}, + true, nntrainer::Tensor::Initializer::ZEROS); + + EXPECT_EQ(input.isValid(), true); + + input.setValue(0, 0, 0, 0, std::nan("1")); + + EXPECT_EQ(input.isValid(), false); + + input.setValue(0, 0, 0, 0, std::numeric_limits::infinity()); + + EXPECT_EQ(input.isValid(), false); + + input.setValue(0, 0, 0, 0, 1); + + EXPECT_EQ(input.isValid(), true); +} + GTEST_API_ int main(int argc, char **argv) { int result = -1; diff --git 
a/tools/package_android.sh b/tools/package_android.sh index 6e02cc23d2..5fc7ba8754 100755 --- a/tools/package_android.sh +++ b/tools/package_android.sh @@ -17,14 +17,14 @@ if [ ! -d builddir ]; then #default value of openblas num threads is 1 for android #enable-tflite-interpreter=false is just temporally until ci system is stabel #enable-opencl=true will compile OpenCL related changes or remove this option to exclude OpenCL compilations. - meson builddir -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false -Denable-tflite-backbone=false -Denable-fp16=true -Denable-neon=true -Domp-num-threads=1 -Denable-opencl=true + meson builddir -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false -Denable-tflite-backbone=false -Denable-fp16=true -Denable-neon=true -Domp-num-threads=1 -Denable-opencl=true -Denable-avx=false else echo "warning: $TARGET/builddir has already been taken, this script tries to reconfigure and try building" pushd builddir #default value of openblas num threads is 1 for android #enable-tflite-interpreter=false is just temporally until ci system is stabel #enable-opencl=true will compile OpenCL related changes or remove this option to exclude OpenCL compilations. - meson configure -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false -Denable-tflite-backbone=false -Denable-fp16=true -Denable-neon=true -Domp-num-threads=1 -Denable-opencl=true + meson configure -Dplatform=android -Dopenblas-num-threads=1 -Denable-tflite-interpreter=false -Denable-tflite-backbone=false -Denable-fp16=true -Denable-neon=true -Domp-num-threads=1 -Denable-opencl=true -Denable-avx=false meson --wipe popd fi
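Note on the recorder_v2.py changes above: they follow PyTorch's standard mixed-precision recipe (autocast for fp16 compute plus GradScaler for loss scaling) and re-run an iteration whenever the scaled gradients contain inf/NaN, so that only finite golden data is written. The sketch below is a minimal, self-contained version of that recipe for reference only; it is not part of the patch, and the toy model, tensor shapes, and learning rate are placeholder assumptions.

# Minimal AMP training-loop sketch (illustrative; model, shapes, and optimizer are placeholders).
import torch
from torch import autocast
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(3, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = GradScaler()

for _ in range(3):
    inputs = torch.randn(1, 3, device="cuda")
    labels = torch.randn(1, 10, device="cuda")

    optimizer.zero_grad()
    with autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)          # forward pass runs in fp16 under autocast
        loss = loss_fn(outputs, labels)

    scaler.scale(loss).backward()        # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)               # unscales gradients; skips the step on inf/NaN
    scaler.update()                      # adjusts the loss scale for the next iteration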