diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h
index 7e76134c5b..7fcf1b06d6 100644
--- a/api/ccapi/include/layer.h
+++ b/api/ccapi/include/layer.h
@@ -354,21 +354,11 @@ Reshape(const std::vector<std::string> &properties = {}) {
 /**
  * @brief Helper function to create addition layer
  */
-inline std::unique_ptr<Layer>
-Addition(const std::vector<std::string> &properties = {}) {
-  return createLayer(LayerType::LAYER_ADDITION, properties);
-}
-
-#ifdef ENABLE_OPENCL
-/**
- * @brief Helper function to create Addition layer for GPU
- */
-inline std::unique_ptr<Layer>
-AdditionCL(const std::vector<std::string> &properties = {},
-           const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
+inline std::unique_ptr<Layer> Addition(
+  const std::vector<std::string> &properties = {},
+  const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
   return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
 }
-#endif
 
 /**
  * @brief Helper function to create concat layer
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
index 48ea84d191..1cd9f1de41 100644
--- a/nntrainer/layers/cl_layers/addition_layer_cl.cpp
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
@@ -3,7 +3,7 @@
  * Copyright (C) 2024 Yash Singh
  *
  * @file   addition_layer_cl.cpp
- * @date   17 May 2024
+ * @date   28 May 2024
  * @see    https://github.com/nnstreamer/nntrainer
  * @author Yash Singh <yash.singh@samsung.com>
  * @bug    No known bugs except for NYI items
@@ -11,6 +11,7 @@
  * implementation
  */
 
+#include <blas_kernels.h>
 #include
 #include
 #include
@@ -19,15 +20,6 @@
 #include
 
-std::string addition_cl_kernel_ =
-  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
-  #pragma printf_support
-  size_t idx = get_global_id(0);
-  if (idx < size) {
-    output[idx] = output[idx] + input[idx];
-  }
-})";
-
 namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
@@ -45,18 +37,11 @@ void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) {
     if (!idx) {
       hidden_.copy(input_);
     } else {
-      // hidden_.add_i(input_);
       AddProcess(input_, hidden_, context);
     }
   }
 }
 
-/**
- * @brief declaring static kerinputnel objects
- *
- */
-opencl::Kernel AdditionLayerCL::kernel_addition;
-
 void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
                                  RunLayerContext &context) {
@@ -83,67 +68,6 @@ void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
   throw std::invalid_argument("Error: OpenCL fp16 is not supported yet.");
 }
 
-void AdditionLayerCL::addition_cl(const float *input, float *res,
-                                  unsigned int size, RunLayerContext &context) {
-
-  bool result = false;
-  do {
-    result = result =
-      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
-                             AdditionLayerCL::kernel_addition);
-    if (!result) {
-      break;
-    }
-
-    size_t dim1_size = sizeof(float) * size;
-    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
-
-    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
-
-    result = inputA.WriteData(context.command_queue_inst_, input);
-    if (!result) {
-      break;
-    }
-
-    result = inOutRes.WriteData(context.command_queue_inst_, res);
-    if (!result) {
-      break;
-    }
-
-    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
-      0, &inputA, sizeof(cl_mem));
-    if (!result) {
-      break;
-    }
-
-    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
-      1, &inOutRes, sizeof(cl_mem));
-    if (!result) {
-      break;
-    }
-
-    result = AdditionLayerCL::kernel_addition.SetKernelArguments(2, &size,
-                                                                 sizeof(int));
-    if (!result) {
-      break;
-    }
-
-    const int work_groups_count[3] = {(int)size, 1, 1};
-    const int work_group_size[3] = {32, 32, 1}; // test-value
-    result = context.command_queue_inst_.DispatchCommand(
-      AdditionLayerCL::kernel_addition, work_groups_count, work_group_size);
-    if (!result) {
-      break;
-    }
-
-    result = inOutRes.ReadData(context.command_queue_inst_, res);
-    if (!result) {
-      break;
-    }
-
-  } while (false);
-}
-
 void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
                                              unsigned int from, unsigned int to,
                                              bool training) {
@@ -179,7 +103,6 @@ void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
     if (!idx) {
       hidden_step.copy(input_step);
     } else {
-      // hidden_step.add_i(input_step);
       AddProcess(input_step, hidden_step, context);
     }
   }
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.h b/nntrainer/layers/cl_layers/addition_layer_cl.h
index 78b9293351..b556746a7c 100644
--- a/nntrainer/layers/cl_layers/addition_layer_cl.h
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.h
@@ -3,7 +3,7 @@
  * Copyright (C) 2024 Yash Singh
  *
  * @file   addition_layer_cl.h
- * @date   17 May 2024
+ * @date   28 May 2024
  * @see    https://github.com/nnstreamer/nntrainer
  * @author Yash Singh <yash.singh@samsung.com>
  * @bug    No known bugs except for NYI items
@@ -17,8 +17,6 @@
 #include
 #include
-#include <opencl_buffer.h>
-#include <opencl_kernel.h>
 
 #define CREATE_IF_EMPTY_DIMS(tensor, ...) \
   do {                                    \
@@ -78,11 +76,6 @@ class AdditionLayerCL : public Layer {
    */
   void calcDerivative(RunLayerContext &context) override;
 
-  /**
-   * @brief declaring static kernel objects
-   */
-  static opencl::Kernel kernel_addition;
-
   /**
    * @brief Process data and dimensions for add operation used in addition layer
    * @param[in] input Tensor
@@ -92,16 +85,6 @@ class AdditionLayerCL : public Layer {
   void AddProcess(Tensor const &input, Tensor &result,
                   RunLayerContext &context);
 
-  /**
-   * @brief addition : sum of all input vectors
-   * @param[in] input float * for input
-   * @param[in] res float * for result/output
-   * @param[in] size number of elements in input vector
-   * @param[in] context RunLayerContext reference
-   */
-  void addition_cl(const float *input, float *res, unsigned int size,
-                   RunLayerContext &context);
-
   /**
    * @copydoc bool supportBackwarding() const
   */
diff --git a/nntrainer/layers/cl_layers/blas_kernels.cpp b/nntrainer/layers/cl_layers/blas_kernels.cpp
index b994afd731..341f73a244 100644
--- a/nntrainer/layers/cl_layers/blas_kernels.cpp
+++ b/nntrainer/layers/cl_layers/blas_kernels.cpp
@@ -51,12 +51,22 @@ std::string sgemm_cl_kernel_ =
     C[m * ldc + n] = c;
   })";
 
+std::string addition_cl_kernel_ =
+  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
+  #pragma printf_support
+  size_t idx = get_global_id(0);
+  if (idx < size) {
+    output[idx] = output[idx] + input[idx];
+  }
+  })";
+
 /**
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv;
 opencl::Kernel kernel_sgemm;
 opencl::Kernel kernel_dot;
+opencl::Kernel kernel_addition;
 
 void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
               unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -299,4 +309,62 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
   } while (false);
 }
 
+void addition_cl(const float *input, float *res, unsigned int size,
+                 RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result =
+      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
+                             kernel_addition);
+    if (!result) {
+      break;
+    }
+
+    size_t dim1_size = sizeof(float) * size;
+    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition.SetKernelArguments(2, &size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+    result = context.command_queue_inst_.DispatchCommand(
+      kernel_addition, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
 } // namespace nntrainer
diff --git a/nntrainer/layers/cl_layers/blas_kernels.h b/nntrainer/layers/cl_layers/blas_kernels.h
index 57c82ef8ad..addb2d72f8 100644
--- a/nntrainer/layers/cl_layers/blas_kernels.h
+++ b/nntrainer/layers/cl_layers/blas_kernels.h
@@ -30,6 +30,7 @@ extern opencl::Kernel kernel_sgemm;
 extern opencl::Kernel kernel_sgemm_fp16;
 extern opencl::Kernel kernel_dot;
 extern opencl::Kernel kernel_dot_fp16;
+extern opencl::Kernel kernel_addition;
 
 /**
  * @brief sgemv computation : Y = A*X + Y
@@ -117,5 +118,15 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
               unsigned int N, unsigned int K, unsigned int lda,
               unsigned int ldb, unsigned int ldc, RunLayerContext &context);
 
+/**
+ * @brief addition : sum of all input vectors
+ * @param[in] input float * for input
+ * @param[in,out] res float * for result/output
+ * @param[in] size number of elements in input vector
+ * @param[in] context RunLayerContext reference
+ */
+void addition_cl(const float *input, float *res, unsigned int size,
+                 RunLayerContext &context);
+
 } // namespace nntrainer
 #endif /* __BLAS_KERNELS_H__ */
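
For reference, a minimal sketch of how the unified helper reads after this change. It assumes the usual ccapi namespaces (ml::train::Layer, ml::train::layer, ml::train::LayerComputeEngine) and illustrative layer names; none of this is part of the patch itself, and GPU selection only takes effect when OpenCL support is compiled in.

#include <layer.h>
#include <memory>

// With AdditionCL() removed, callers pick the backend through the
// compute_engine parameter of the single Addition() helper.
std::unique_ptr<ml::train::Layer> make_addition(bool on_gpu) {
  if (on_gpu)
    return ml::train::layer::Addition({"name=add_gpu"},
                                      ml::train::LayerComputeEngine::GPU);
  return ml::train::layer::Addition({"name=add_cpu"}); // defaults to CPU
}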
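Because the kernel source and the global opencl::Kernel object now live alongside the other BLAS kernels, any cl_layer can reuse the element-wise accumulation without declaring its own kernel object. A hedged sketch of such a caller follows; the wrapper function is hypothetical, while the addition_cl signature is the one declared in blas_kernels.h above.

#include <blas_kernels.h>

// Hypothetical helper: computes res[i] += input[i] on the GPU for i < size,
// using the shared kernel_addition object managed by blas_kernels. Kernel
// creation, buffer transfers, and dispatch all happen inside addition_cl.
void accumulate_on_gpu(const float *input, float *res, unsigned int size,
                       nntrainer::RunLayerContext &context) {
  nntrainer::addition_cl(input, res, size, context);
}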