From 047adc322acd21ca86cdbfd75ec669c67c3b6db2 Mon Sep 17 00:00:00 2001 From: "jijoong.moon" Date: Mon, 29 Jul 2024 10:38:31 +0900 Subject: [PATCH] [Mixed Precision] Fix mixed precsion to use Tensor V2 This PR includes fixes to use TensorV2 Resolves: **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: jijoong.moon --- api/ccapi/include/model.h | 15 +++---- meson.build | 21 ++++++---- nntrainer/app_context.cpp | 8 ++-- nntrainer/graph/network_graph.cpp | 11 ++--- nntrainer/graph/network_graph.h | 11 ++++- nntrainer/layers/bn_layer.cpp | 18 ++++----- nntrainer/layers/conv2d_layer.cpp | 2 + nntrainer/layers/layer_context.h | 6 ++- nntrainer/layers/lstm.cpp | 24 +++++------ nntrainer/layers/pooling2d_layer.cpp | 8 ++-- nntrainer/models/neuralnet.cpp | 1 - nntrainer/tensor/blas_interface.cpp | 3 +- nntrainer/tensor/char_tensor.h | 5 +++ nntrainer/tensor/float_tensor.cpp | 6 +-- nntrainer/tensor/float_tensor.h | 2 +- nntrainer/tensor/half_tensor.cpp | 6 +-- nntrainer/tensor/half_tensor.h | 2 +- nntrainer/tensor/manager.cpp | 12 +++--- nntrainer/tensor/manager.h | 13 +++--- nntrainer/tensor/meson.build | 2 +- nntrainer/tensor/tensor.cpp | 7 ++-- nntrainer/tensor/tensor_base.h | 2 +- nntrainer/tensor/tensor_wrap_specs.h | 6 +-- nntrainer/tensor/uint_tensor.h | 5 +++ nntrainer/tensor/weight.cpp | 9 +---- nntrainer/tensor/weight.h | 40 +++++++++---------- packaging/nntrainer.spec | 4 +- test/unittest/unittest_nntrainer_tensor.cpp | 4 +- .../unittest_nntrainer_tensor_fp16.cpp | 2 +- .../unittest_nntrainer_tensor_neon_fp16.cpp | 32 --------------- 30 files changed, 139 insertions(+), 148 deletions(-) diff --git a/api/ccapi/include/model.h b/api/ccapi/include/model.h index ef6303e6da..e8d185bbdb 100644 --- a/api/ccapi/include/model.h +++ b/api/ccapi/include/model.h @@ -188,13 +188,14 @@ class Model { * @details This function accepts vector of properties in the format - * { std::string property_name, void * 
property_val, ...} */ - virtual int train(const std::vector &values = {}, - std::function stop_cb = - [](void *stop_user_data) { return false; }, - void *stop_user_data = nullptr, - std::function epoch_complete_cb = - [](void *epoch_user_data) { return false; }, - void *epoch_user_data = nullptr) = 0; + virtual int train( + const std::vector &values = {}, + std::function stop_cb = + [](void *stop_user_data) { return false; }, + void *stop_user_data = nullptr, + std::function epoch_complete_cb = + [](void *epoch_user_data) { return false; }, + void *epoch_user_data = nullptr) = 0; /** * @brief Run Model train with callback function by user diff --git a/meson.build b/meson.build index 06e4e7da28..3efebb198a 100644 --- a/meson.build +++ b/meson.build @@ -71,13 +71,20 @@ warning_c_flags = [ arch = host_machine.cpu_family() if get_option('enable-avx') - extra_defines += '-DUSE_AVX=1' - if get_option('platform') == 'tizen' - add_project_arguments(['-mavx2'], language: ['c','cpp']) - else - add_project_arguments(['-march=native'], language: ['c','cpp']) - endif - message('-march=native added for AVX hardware acceleration.') + if get_option('platform') != 'android' + if arch == 'x86_64' or arch == 'x86' + has_avx2 = cc.has_argument('-mavx2') + if (has_avx2) + extra_defines += '-DUSE_AVX=1' + add_project_arguments(['-march=native'], language: ['c','cpp']) + message('-march=native added for AVX hardware acceleration.') + endif + else + message('This arch does not support avx2') + endif + else + message('avx2 is not supported') + endif endif if get_option('enable-fp16') diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp index b6ebb34461..c0e00cff49 100644 --- a/nntrainer/app_context.cpp +++ b/nntrainer/app_context.cpp @@ -550,12 +550,14 @@ std::vector AppContext::registerPluggableFromDirectory(const std::string &base_path) { DIR *dir = opendir(base_path.c_str()); - NNTR_THROW_IF(dir == nullptr, std::invalid_argument) - << func_tag << "failed to open the 
directory: " << base_path; + std::vector keys; + if (dir == NULL) { + closedir(dir); + return keys; + } struct dirent *entry; - std::vector keys; while ((entry = readdir(dir)) != NULL) { if (endswith(entry->d_name, solib_suffix)) { if (endswith(entry->d_name, layerlib_suffix)) { diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index eee11c49f6..6e7a135d87 100644 --- a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -341,6 +341,9 @@ void NetworkGraph::applyGradients( /** * @note the weights whose gradient are to be clipped by global norm will * be clipped at once at the end of iteration and applied then. + * For those weights where mixed precision is uesed, their gradient + * updates might be delayed until they confirm whether their loss scales + * are appropeiate. */ continue; } @@ -487,18 +490,12 @@ bool NetworkGraph::backwarding( } } /** apply the gradient with the above global norm */ - std::cout << "======================================= update gradient " - << std::endl; for (auto w : lazy_weights) { - std::cout << w->getName() << " : "; lazy_apply_grad_op(*w, iteration); } nan_count++; - std::cout << "====================================== update gradient finished" - << std::endl; /** @todo : handle as property : growth_interval : default --> 2000 */ - if (nan_count > 2000) { float scale = (*iter_)->getRunContext().getLossScale(); /** @todo growth_factor : default --> 2.0 */ @@ -1647,7 +1644,7 @@ void NetworkGraph::requestOptimizerVariable( w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables( dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN, w->isGradientClipByGlobalNorm(), w->isMixedPrecision(), - Tensor::Initializer::ZEROS)); + Initializer::ZEROS)); } } } diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 8756e97775..2eb6e4ac27 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -58,7 +58,12 @@ class 
NetworkGraph { /** * @brief Constructor of NeuralNetwork Graph Class * @param[in] enable_swap enable memory swap for tensor + * @param[in] mode execution mode (default ExecutionMode::TRAIN) * @param[in] swap_path memory swap file path when the swap is enabled + * @param[in] tensor_format define tensor format. One of NCHW and NHWC + * (default NCHW) + * @param[in] tensor_type It says weight type and activation type (default + * FP32-FP32) */ NetworkGraph(bool enable_swap, ExecutionMode mode = ExecutionMode::TRAIN, const std::string &swap_path = "", unsigned int lookahead = 0, @@ -207,8 +212,9 @@ class NetworkGraph { /** * @brief backwarding the network graph * @param[in] iteration current iteration number + * @param[in] forwarding_op operation for the forwarding * @param[in] backwarding_op operation for the backwarding - * @param[in] apply_grad_clip_op operation for applying the clip gradients + * @param[in] lazy_apply_grad_op operation for applying the lazy gradients */ bool backwarding( int iteration, @@ -496,7 +502,8 @@ class NetworkGraph { std::unordered_map profile_keys; /**< profile keys based on the layer type */ std::vector - lazy_weights; /**< weights with global norm based clipping enabled */ + lazy_weights; /**< weights with delayed grad update, e.g., gradient + clipping, loss scaling */ bool is_clip_grad; unsigned int nan_count; diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp index dda76ae39f..c5802291d8 100644 --- a/nntrainer/layers/bn_layer.cpp +++ b/nntrainer/layers/bn_layer.cpp @@ -118,12 +118,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { 1.0f, bias_decay, "beta", true); wt_idx[BNParams::mu_b] = - context.requestTensor(dim, "moviing_mean_backup", Tensor::Initializer::NONE, - false, TensorLifespan::ITERATION_LIFESPAN); + context.requestTensor(dim, "moviing_mean_backup", Initializer::NONE, false, + TensorLifespan::ITERATION_LIFESPAN); - wt_idx[BNParams::var_b] = context.requestTensor( - dim, 
"moviing_variance_backup", Tensor::Initializer::NONE, false, - TensorLifespan::ITERATION_LIFESPAN); + wt_idx[BNParams::var_b] = + context.requestTensor(dim, "moviing_variance_backup", Initializer::NONE, + false, TensorLifespan::ITERATION_LIFESPAN); /** * caches the deviation -> input - avg(input) @@ -137,8 +137,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { } wt_idx[BNParams::deviation] = - context.requestTensor(in_dim_, "deviation", Tensor::Initializer::NONE, - false, TensorLifespan::ITERATION_LIFESPAN); + context.requestTensor(in_dim_, "deviation", Initializer::NONE, false, + TensorLifespan::ITERATION_LIFESPAN); /** caches the inverse standard deviation */ wt_idx[BNParams::invstd] = context.requestTensor(dim, "invstd", Initializer::NONE, false, @@ -150,8 +150,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) { * as the output of this layer need not be stored all the time. */ wt_idx[BNParams::t_full] = - context.requestTensor(in_dim_, "tensor_full", Tensor::Initializer::NONE, - false, TensorLifespan::CALC_DERIV_LIFESPAN); + context.requestTensor(in_dim_, "tensor_full", Initializer::NONE, false, + TensorLifespan::CALC_DERIV_LIFESPAN); /** * caches variance + epsilon as well. */ diff --git a/nntrainer/layers/conv2d_layer.cpp b/nntrainer/layers/conv2d_layer.cpp index 70647add5b..8425f1eb56 100644 --- a/nntrainer/layers/conv2d_layer.cpp +++ b/nntrainer/layers/conv2d_layer.cpp @@ -242,6 +242,8 @@ static void im2col(const Tensor &in, const TensorDim &kdim, unsigned int base_im_h = 0; int patch_height_end = eff_k_height + hs; /// map the patch to a single line looping through channel + // We need to optimize this padding & copy. 
Maybe use multiple threads, or
Tensor::Initializer::NONE, true, - TensorLifespan::ITERATION_LIFESPAN); + wt_idx[LSTMParams::cell_state] = + context.requestTensor(cell_state_dim, "cell_state", Initializer::NONE, true, + TensorLifespan::ITERATION_LIFESPAN); // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ] const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit, @@ -594,18 +594,18 @@ void LSTMLayer::finalize(InitLayerContext &context) { // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ] const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit, activation_tensor_type); - wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor( - reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true, - TensorLifespan::ITERATION_LIFESPAN); + wt_idx[LSTMParams::reverse_ifgo] = + context.requestTensor(reverse_ifgo_dim, "reverse_ifgo", Initializer::NONE, + true, TensorLifespan::ITERATION_LIFESPAN); } if (dropout_rate > epsilon) { // dropout_mask_dim = [ batch, 1, time_iteration, unit ] const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit, activation_tensor_type); - wt_idx[LSTMParams::dropout_mask] = context.requestTensor( - dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false, - TensorLifespan::ITERATION_LIFESPAN); + wt_idx[LSTMParams::dropout_mask] = + context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE, + false, TensorLifespan::ITERATION_LIFESPAN); } if (context.getActivationDataType() == TensorDim::DataType::FP32) { diff --git a/nntrainer/layers/pooling2d_layer.cpp b/nntrainer/layers/pooling2d_layer.cpp index 676a0a6128..cf71ae8ae3 100644 --- a/nntrainer/layers/pooling2d_layer.cpp +++ b/nntrainer/layers/pooling2d_layer.cpp @@ -126,15 +126,15 @@ void Pooling2DLayer::finalize(InitLayerContext &context) { auto helper_dim = in_dim; helper_dim.setDataType(ml::train::TensorDim::DataType::FP32); pool_helper_idx = - context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE, - 
false, TensorLifespan::ITERATION_LIFESPAN); + context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false, + TensorLifespan::ITERATION_LIFESPAN); pool_helper_size.resize(helper_dim.batch() * helper_dim.channel()); } else { auto helper_dim = out_dim; helper_dim.setDataType(ml::train::TensorDim::DataType::FP32); pool_helper_idx = - context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE, - false, TensorLifespan::ITERATION_LIFESPAN); + context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false, + TensorLifespan::ITERATION_LIFESPAN); } } diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index 12d6ecd0d7..3833489369 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -1160,7 +1160,6 @@ int NeuralNetwork::train_run( auto epochs = getEpochs(); ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.", epoch_idx + 1, getEpochs()); - epoch_idx = 0; for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) { if (stop_cb(stop_user_data)) { --epoch_idx; diff --git a/nntrainer/tensor/blas_interface.cpp b/nntrainer/tensor/blas_interface.cpp index ebf3b3478b..8dc509960b 100644 --- a/nntrainer/tensor/blas_interface.cpp +++ b/nntrainer/tensor/blas_interface.cpp @@ -874,8 +874,7 @@ void scopy(const unsigned int N, const float *X, const int incX, float *Y, #ifdef BLAS_NUM_THREADS openblas_set_num_threads(BLAS_NUM_THREADS); #endif - // cblas_scopy(N, (float*)(X), incX, (float*)(Y), incY); - // replace cblas scopy with raw temporary. 
+ // cblas_scopy(N, X, incX, Y, incY); for (unsigned int i = 0; i < N; ++i) Y[i * incY] = X[i * incX]; #else diff --git a/nntrainer/tensor/char_tensor.h b/nntrainer/tensor/char_tensor.h index 366d11c148..4bf8563162 100644 --- a/nntrainer/tensor/char_tensor.h +++ b/nntrainer/tensor/char_tensor.h @@ -231,6 +231,11 @@ class CharTensor : public TensorBase { * @return std::string of tensor data type (QINT8) */ std::string getStringDataType() const override { return "QINT8"; } + + /** + * @copydoc Tensor::isValid() + */ + bool isValid() const override { return true; }; // NYI }; } // namespace nntrainer diff --git a/nntrainer/tensor/float_tensor.cpp b/nntrainer/tensor/float_tensor.cpp index 92249ec1a6..22600f32cb 100644 --- a/nntrainer/tensor/float_tensor.cpp +++ b/nntrainer/tensor/float_tensor.cpp @@ -150,7 +150,7 @@ void FloatTensor::setZero() { // sscal(size(), 0, getData(), 1); /// @note we cannot use sscal, when we set zero. if the data is inf or /// NaN, then the inf or NaN still remain. 
- memset(getData(), 0, sizeof(float) * size()); + memset((float *)getData(), 0, sizeof(float) * size()); } else { /// @todo implement apply_i // apply_i([](float val) -> float { return 0; }); @@ -1210,8 +1210,8 @@ void FloatTensor::apply_broadcast( return apply_broadcast_util(m, v_func, output, this->computeBroadcastInfo(m)); } -bool Tensor::isValid() const { - return is_valid(dim.getDataLen(), Tdatatype::FP32, getData()); +bool FloatTensor::isValid() const { + return is_valid(dim.getDataLen(), Tdatatype::FP32, (float *)getData()); } } // namespace nntrainer diff --git a/nntrainer/tensor/float_tensor.h b/nntrainer/tensor/float_tensor.h index 0829505364..23681fc339 100644 --- a/nntrainer/tensor/float_tensor.h +++ b/nntrainer/tensor/float_tensor.h @@ -511,7 +511,7 @@ class FloatTensor : public TensorBase { /** * @copydoc Tensor::isValid() */ - bool Tensor::isValid() const; + bool isValid() const override; }; } // namespace nntrainer diff --git a/nntrainer/tensor/half_tensor.cpp b/nntrainer/tensor/half_tensor.cpp index b934d8706e..55e072ed74 100644 --- a/nntrainer/tensor/half_tensor.cpp +++ b/nntrainer/tensor/half_tensor.cpp @@ -149,7 +149,7 @@ void HalfTensor::setZero() { // sscal(size(), 0, (_FP16 *)getData(), 1); /// @note we cannot use sscal, when we set zero. if the data is inf or /// NaN, then the inf or NaN still remain. 
- memset(getData<_FP16>(), 0, sizeof(_FP16) * size()); + memset((_FP16 *)getData(), 0, sizeof(_FP16) * size()); } else { /// @todo implement apply_i // apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; }); @@ -1176,8 +1176,8 @@ void HalfTensor::apply_broadcast_util( } } -bool Tensor::isValid() const { - return is_valid(dim.getDataLen(), Tdatatype::FP16, getData<_FP16>()); +bool HalfTensor::isValid() const { + return is_valid(dim.getDataLen(), Tdatatype::FP16, (_FP16 *)getData()); } } // namespace nntrainer diff --git a/nntrainer/tensor/half_tensor.h b/nntrainer/tensor/half_tensor.h index d79afc0805..206a8482de 100644 --- a/nntrainer/tensor/half_tensor.h +++ b/nntrainer/tensor/half_tensor.h @@ -502,7 +502,7 @@ class HalfTensor : public TensorBase { /** * @copydoc Tensor::isValid() */ - bool Tensor::isValid() const; + bool isValid() const override; }; } // namespace nntrainer diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp index 37fd92074e..e454e51119 100644 --- a/nntrainer/tensor/manager.cpp +++ b/nntrainer/tensor/manager.cpp @@ -434,7 +434,7 @@ std::vector Manager::requestWeights( */ grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix, dim_g, grad_exec_order, grad_ls, - Tensor::Initializer::ZEROS); + Initializer::ZEROS); if (var->getDataType() != ml::train::TensorDim::DataType::FP32) { TensorDim var32_dim(dim_v); @@ -444,7 +444,7 @@ std::vector Manager::requestWeights( var32 = weight_pool.requestOrExtend(shared_name + ":var32", var32_dim, var32_exec_order, var_ls, - Tensor::Initializer::ZEROS); + Initializer::ZEROS); } } } else { @@ -461,8 +461,8 @@ std::vector Manager::requestWeights( // if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) // is_wgrad = false; grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g, - grad_exec_order, grad_ls, - Tensor::Initializer::ZEROS, is_wgrad); + grad_exec_order, grad_ls, Initializer::ZEROS, + is_wgrad); if (var->getDataType() != ml::train::TensorDim::DataType::FP32) 
{ TensorDim var32_dim(dim_v); var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); @@ -470,7 +470,7 @@ std::vector Manager::requestWeights( var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER); var32 = weight_pool.request(name + ":var32", var32_dim, var32_exec_order, - var_ls, Tensor::Initializer::ZEROS); + var_ls, Initializer::ZEROS); } } } @@ -690,7 +690,7 @@ bool Manager::isSecondLastAccess(const std::string &name, std::vector Manager::requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, const std::string &suffix, const TensorLifespan &lifespan, bool is_grad_clip, - bool is_mixed_precision, Tensor::Initializer initializer) { + bool is_mixed_precision, Initializer initializer) { std::vector ret; ret.reserve(dims.size()); diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h index 43ab364d50..281b6ebe4e 100644 --- a/nntrainer/tensor/manager.h +++ b/nntrainer/tensor/manager.h @@ -141,18 +141,19 @@ class Manager { /** * @brief Constructor of Manager */ - Manager(bool enable_swap, const std::string &swap_path = "", + Manager(bool enable_swap_, const std::string &swap_path = "", unsigned int lookahead = 0, const std::string tensor_format_ = "NCHW", const std::string tensor_dtype_ = "FP32-FP32", ExecutionMode exec_mode_ = ExecutionMode::TRAIN) : - weight_pool(enable_swap, swap_path, "weight_pool"), - tensor_pool(enable_swap && (exec_mode_ == ExecutionMode::TRAIN), swap_path, + weight_pool(enable_swap_, swap_path, "weight_pool"), + tensor_pool(enable_swap_ && (exec_mode_ == ExecutionMode::TRAIN), swap_path, "tensor_pool"), enable_optimizations(true), swap_lookahead(lookahead), tensor_format(tensor_format_), tensor_dtype(split(tensor_dtype_, getRegex("\\-"))), - exec_mode(exec_mode_) {} + exec_mode(exec_mode_), + enable_swap(enable_swap_) {} /** * @brief Construct a new Manager object (deleted) @@ -228,7 +229,7 @@ class Manager { const std::vector &dims, const std::string &name, const std::string &suffix, const 
TensorLifespan &lifespan, bool is_grad_clip, bool is_mixed_type, - Tensor::Initializer initializer = Tensor::Initializer::NONE); + Initializer initializer = Initializer::NONE); /** * @brief Create tensors with the given spec @@ -537,6 +538,8 @@ class Manager { ExecutionMode exec_mode; + bool enable_swap; + /** * @brief Finalize the given tensor pool * diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build index e5682633ac..a9d05043c0 100644 --- a/nntrainer/tensor/meson.build +++ b/nntrainer/tensor/meson.build @@ -48,7 +48,7 @@ tensor_headers = [ arch = host_machine.cpu_family() -if get_option('enable-avx') +if get_option('enable-avx') and get_option('platform') != 'android' tensor_sources += 'blas_avx.cpp' tensor_headers += 'blas_avx.h' endif diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index ac0c2c7f90..932918d971 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -1024,10 +1024,11 @@ Tensor Tensor::clone() const { Tensor Tensor::clone(ml::train::TensorDim::DataType type) const { if (getDataType() == type) return clone(); - - Tensor output(getName(), getFormat(), type); + TensorDim dim = getDim(); + dim.setDataType(type); + Tensor output(dim, true); output.copyData(*this); - output.name = name; + output.setName(getName()); return output; } diff --git a/nntrainer/tensor/tensor_base.h b/nntrainer/tensor/tensor_base.h index 7d4e564689..b147a502bf 100644 --- a/nntrainer/tensor/tensor_base.h +++ b/nntrainer/tensor/tensor_base.h @@ -680,7 +680,7 @@ class TensorBase { /** * @copydoc Tensor::isValid() */ - bool isValid() const { return true; }; + virtual bool isValid() const = 0; static constexpr float epsilon = 1e-5; diff --git a/nntrainer/tensor/tensor_wrap_specs.h b/nntrainer/tensor/tensor_wrap_specs.h index cd8d9e7082..a1c6f68bca 100644 --- a/nntrainer/tensor/tensor_wrap_specs.h +++ b/nntrainer/tensor/tensor_wrap_specs.h @@ -75,9 +75,9 @@ enum class TensorLifespan { * regularizer_constant, decay, 
clip gradient constant, need_gradient property, * name, output axis of the tensor object and loss Scale Factor, is_mixed. */ -typedef std::tuple +typedef std::tuple WeightSpec; /** diff --git a/nntrainer/tensor/uint_tensor.h b/nntrainer/tensor/uint_tensor.h index 7c09ad93de..75f00d92a5 100644 --- a/nntrainer/tensor/uint_tensor.h +++ b/nntrainer/tensor/uint_tensor.h @@ -259,6 +259,11 @@ template class UIntTensor : public TensorBase { else throw std::runtime_error("unsupported type"); } + + /** + * @copydoc Tensor::isValid() + */ + bool isValid() const override { return true; }; // NYI }; /****** Alias for UIntTensors ******/ diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp index ca5328b1fc..778138285e 100644 --- a/nntrainer/tensor/weight.cpp +++ b/nntrainer/tensor/weight.cpp @@ -31,7 +31,7 @@ Weight::Weight(const TensorDim &dim, const Initializer init, output_axis(axis), loss_scale(loss_scale_), is_mixed(is_mixed_) { - if (init == Tensor::Initializer::NONE) + if (init == Initializer::NONE) throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); @@ -73,7 +73,7 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, output_axis(axis), loss_scale(loss_scale_), is_mixed(is_mixed_) { - if (init == Tensor::Initializer::NONE) + if (init == Initializer::NONE) throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); @@ -84,9 +84,6 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, if (train && dim_v.getDataType() != ml::train::TensorDim::DataType::FP32) { TensorDim var32_dim(dim_v); var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); - std::string var32_suffix = ":fp32"; - std::string var32_name = name + var32_suffix; - var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); } else 
{ var32 = std::make_shared(var32_name); @@ -137,8 +134,6 @@ void Weight::applyGradient(double lr, Tensor &updated_grad) { updated_grad.getDataType() == ml::train::TensorDim::DataType::FP32 && var->getDataType() != ml::train::TensorDim::DataType::FP32) { var32->add_i(updated_grad, -lr); - std::cout << var32->getName() << " --------------------------" << std::endl; - var32->print(std::cout); quantizeWeight(); return; } else { diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h index f12f9597ec..4db4b106ed 100644 --- a/nntrainer/tensor/weight.h +++ b/nntrainer/tensor/weight.h @@ -60,14 +60,14 @@ class Weight : public Var_Grad { * @param alloc_now The memory for the weight tensors be allocated upon init * @param name Name for this weight */ - explicit Weight( - const TensorDim &dim, - const Tensor::Initializer init = Tensor::Initializer::XAVIER_UNIFORM, - const WeightRegularizer reg = WeightRegularizer::NONE, - const float reg_const = 1.0f, const float decay = 0.0f, - const float clip_by_global_norm = 0.0f, bool ng = true, - bool alloc_now = false, std::string name = "", unsigned int axis = 3, - float loss_scale_ = 1.0, bool is_mixed = false); + explicit Weight(const TensorDim &dim, + const Initializer init = Initializer::XAVIER_UNIFORM, + const WeightRegularizer reg = WeightRegularizer::NONE, + const float reg_const = 1.0f, const float decay = 0.0f, + const float clip_by_global_norm = 0.0f, bool ng = true, + bool alloc_now = false, std::string name = "", + unsigned int axis = 3, float loss_scale_ = 1.0, + bool is_mixed = false); /** * @brief Construct a new Weight object @@ -81,14 +81,14 @@ class Weight : public Var_Grad { * @param alloc_now The memory for the weight tensors be allocated upon init * @param name Name for this weight */ - explicit Weight( - const TensorDim &dim_v, const TensorDim &dim_g, - const Tensor::Initializer init = Tensor::Initializer::XAVIER_UNIFORM, - const WeightRegularizer reg = WeightRegularizer::NONE, - const float reg_const 
= 1.0f, const float decay = 0.0f, - const float clip_by_global_norm = 0.0f, bool ng = true, - bool alloc_now = false, std::string name = "", unsigned int axis = 3, - float loss_scale_ = 1.0, bool is_mixed = false); + explicit Weight(const TensorDim &dim_v, const TensorDim &dim_g, + const Initializer init = Initializer::XAVIER_UNIFORM, + const WeightRegularizer reg = WeightRegularizer::NONE, + const float reg_const = 1.0f, const float decay = 0.0f, + const float clip_by_global_norm = 0.0f, bool ng = true, + bool alloc_now = false, std::string name = "", + unsigned int axis = 3, float loss_scale_ = 1.0, + bool is_mixed = false); /** * @brief Construct a new Weight object @@ -116,7 +116,7 @@ class Weight : public Var_Grad { * * @param v Already created variable object * @param g Already created gradient object - * @param v32 Already created gradient object + * @param v32 Already created var32 object * @param n Name for this Weight * * @note This is primarily used to created wrapper of variable extracted from @@ -288,11 +288,7 @@ class Weight : public Var_Grad { /** * @brief Apply the gradient to the weight */ - void applyGradient(double lr) { - var->add_i(*grad.get(), -lr); - std::cout << var->getName() << " --------------------------" << std::endl; - var->print(std::cout); - } + void applyGradient(double lr) { var->add_i(*grad.get(), -lr); } /** * @brief Apply the gradient to the weight with updated gradient diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec index 4793759cfc..b0837c51c6 100644 --- a/packaging/nntrainer.spec +++ b/packaging/nntrainer.spec @@ -1,7 +1,7 @@ # Execute gbs with --define "testcoverage 1" in case that you must get unittest coverage statistics %define use_cblas 1 %define nnstreamer_filter 1 -%define nnstreamer_trainer 0 +%define nnstreamer_trainer 1 %define nnstreamer_subplugin_path /usr/lib/nnstreamer %define use_gym 0 %define support_ccapi 1 @@ -574,6 +574,7 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/ 
%{_includedir}/nntrainer/fp16.h %{_includedir}/nntrainer/util_simd.h %{_includedir}/nntrainer/loss_layer.h +%ifarch aarch64 %if 0%{?enable_fp16} %{_includedir}/nntrainer/util_simd_neon.h %{_includedir}/nntrainer/blas_neon.h @@ -592,7 +593,6 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/ %{_includedir}/nntrainer/model_common_properties.h %{_includedir}/nntrainer/network_graph.h %{_includedir}/nntrainer/graph_core.h -%{_includedir}/nntrainer/graph_node.h %{_includedir}/nntrainer/manager.h %{_includedir}/nntrainer/basic_planner.h %{_includedir}/nntrainer/memory_planner.h diff --git a/test/unittest/unittest_nntrainer_tensor.cpp b/test/unittest/unittest_nntrainer_tensor.cpp index 3231347479..73cf6bd461 100644 --- a/test/unittest/unittest_nntrainer_tensor.cpp +++ b/test/unittest/unittest_nntrainer_tensor.cpp @@ -5464,7 +5464,7 @@ TEST(nntrainer_Tensor, inv_sqrt_i_uncontiguous_p) { } /** - * @brief fp16 tensor has NaN + * @brief float tensor has NaN */ TEST(nntrainer_Tensor, is_valid_01) { size_t batch = 1; @@ -5478,7 +5478,7 @@ TEST(nntrainer_Tensor, is_valid_01) { height, width, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}}, - true, nntrainer::Tensor::Initializer::ZEROS); + true, nntrainer::Initializer::ZEROS); EXPECT_EQ(input.isValid(), true); diff --git a/test/unittest/unittest_nntrainer_tensor_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_fp16.cpp index 1b75eddcc3..106ea484b9 100644 --- a/test/unittest/unittest_nntrainer_tensor_fp16.cpp +++ b/test/unittest/unittest_nntrainer_tensor_fp16.cpp @@ -6266,7 +6266,7 @@ TEST(nntrainer_Tensor, is_valid_01) { height, width, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}}, - true, nntrainer::Tensor::Initializer::ZEROS); + true, nntrainer::Initializer::ZEROS); EXPECT_EQ(input.isValid(), true); diff --git a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp index 7a8eea6ad9..8c0204c4c2 100644 --- 
a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp +++ b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp @@ -1571,38 +1571,6 @@ TEST(nntrainer_Tensor, inv_sqrt_i_p) { EXPECT_EQ(flag, true); } -/** - * @brief fp16 tensor has NaN - */ -TEST(nntrainer_Tensor, is_valid_01) { - size_t batch = 1; - size_t channel = 3; - size_t height = 4; - size_t width = 5; - - nntrainer::Tensor input( - {batch, - channel, - height, - width, - {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}}, - true, nntrainer::Tensor::Initializer::ZEROS); - - EXPECT_EQ(input.isValid(), true); - - input.setValue(0, 0, 0, 0, std::nan("1")); - - EXPECT_EQ(input.isValid(), false); - - input.setValue(0, 0, 0, 0, std::numeric_limits::infinity()); - - EXPECT_EQ(input.isValid(), false); - - input.setValue(0, 0, 0, 0, 1); - - EXPECT_EQ(input.isValid(), true); -} - GTEST_API_ int main(int argc, char **argv) { int result = -1;