From a8f089fe1602379511d7a506bd32198e0ff99697 Mon Sep 17 00:00:00 2001
From: "yash.singh" <yash.singh@samsung.com>
Date: Thu, 23 May 2024 16:12:12 +0530
Subject: [PATCH 1/2] [GPU/OpenCL] Initial version of Addition Layer with OpenCL ops

Added naive version of OpenCL implementation for Addition Layer.
Incorporated kernel for ops used.
Added unit test for addition_layer_cl.

Signed-off-by: yash.singh <yash.singh@samsung.com>
---
 api/ccapi/include/layer.h                      |  11 +
 nntrainer/cl_context.cpp                       |   5 +
 .../layers/cl_layers/addition_layer_cl.cpp     | 210 ++++++++++++++++++
 .../layers/cl_layers/addition_layer_cl.h       | 136 ++++++++++++
 nntrainer/layers/cl_layers/meson.build         |   3 +-
 nntrainer/layers/layer_context.cpp             |   2 +
 nntrainer/layers/layer_context.h               |   3 +-
 test/input_gen/gen_layer_tests.py              |   9 +
 test/jni/Android.mk                            |   1 +
 .../layers/unittest_layers_addition_cl.cpp     |  50 +++++
 10 files changed, 428 insertions(+), 2 deletions(-)
 create mode 100644 nntrainer/layers/cl_layers/addition_layer_cl.cpp
 create mode 100644 nntrainer/layers/cl_layers/addition_layer_cl.h
 create mode 100644 test/unittest/layers/unittest_layers_addition_cl.cpp

diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h
index ca0ae19f62..7e76134c5b 100644
--- a/api/ccapi/include/layer.h
+++ b/api/ccapi/include/layer.h
@@ -359,6 +359,17 @@ Addition(const std::vector<std::string> &properties = {}) {
   return createLayer(LayerType::LAYER_ADDITION, properties);
 }
 
+#ifdef ENABLE_OPENCL
+/**
+ * @brief Helper function to create Addition layer for GPU
+ */
+inline std::unique_ptr<Layer>
+AdditionCL(const std::vector<std::string> &properties = {},
+           const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
+  return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
+}
+#endif
+
 /**
  * @brief Helper function to create concat layer
  */
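
For context, a minimal usage sketch of the new helper (illustrative only, not part of the patch; assumes an ENABLE_OPENCL build and an already-constructed ml::train model object named model):

    // Hypothetical application code: create an addition layer that runs
    // on the GPU through the new AdditionCL helper.
    auto add_gpu = ml::train::layer::AdditionCL(
      {"name=add0"}, ml::train::LayerComputeEngine::GPU);
    model->addLayer(std::move(add_gpu));
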
diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp
index be7345eed0..b92a14ca0d 100644
--- a/nntrainer/cl_context.cpp
+++ b/nntrainer/cl_context.cpp
@@ -12,6 +12,7 @@
  * creates the OpenCL command queue and context.
  */
 
+#include <addition_layer_cl.h>
 #include <cl_context.h>
 #include <fc_layer_cl.h>
 
@@ -26,6 +27,10 @@ static void add_default_object(ClContext &cc) {
   cc.registerFactory(nntrainer::createLayer<FullyConnectedLayerCl>,
                      FullyConnectedLayerCl::type,
                      ml::train::LayerType::LAYER_FC);
+
+  cc.registerFactory(nntrainer::createLayer<AdditionLayerCL>,
+                     AdditionLayerCL::type,
+                     ml::train::LayerType::LAYER_ADDITION);
 }
 
 static void registerer(ClContext &cc) noexcept {
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
new file mode 100644
index 0000000000..48ea84d191
--- /dev/null
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
@@ -0,0 +1,210 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh
+ *
+ * @file addition_layer_cl.cpp
+ * @date 17 May 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is Addition Layer Class for Neural Network with OpenCL
+ * implementation
+ */
+
+#include <addition_layer_cl.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <util_func.h>
+
+#include <layer_context.h>
+
+std::string addition_cl_kernel_ =
+  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
+  #pragma printf_support
+  size_t idx = get_global_id(0);
+  if (idx < size) {
+    output[idx] = output[idx] + input[idx];
+  }
+})";
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+void AdditionLayerCL::finalize(InitLayerContext &context) {
+  context.setOutputDimensions({context.getInputDimensions()[0]});
+}
+
+void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) {
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+
+  /** @todo check possibility for in-place of addition layer */
+  for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+    const Tensor &input_ = context.getInput(idx);
+    if (!idx) {
+      hidden_.copy(input_);
+    } else {
+      // hidden_.add_i(input_);
+      AddProcess(input_, hidden_, context);
+    }
+  }
+}
+
+/**
+ * @brief declaring static kernel objects
+ *
+ */
+opencl::Kernel AdditionLayerCL::kernel_addition;
+
+void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
+                                 RunLayerContext &context) {
+
+  CREATE_IF_EMPTY_DIMS(result, result.getDim());
+
+  NNTR_THROW_IF(result.getData() == nullptr, std::invalid_argument)
+    << result.getName() << " is not allocated";
+  NNTR_THROW_IF(input.getData() == nullptr, std::invalid_argument)
+    << input.getName() << " is not allocated";
+
+  if (input.getDim() != result.getDim()) {
+    throw std::invalid_argument(
+      "Error: Dimensions do not match for addition");
+  }
+
+  if (input.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    unsigned int size = input.size();
+    const float *data = input.getData();
+    float *rdata = result.getData();
+
+    addition_cl(data, rdata, size, context);
+
+  } else
+    throw std::invalid_argument("Error: OpenCL fp16 is not supported yet.");
+}
+
+void AdditionLayerCL::addition_cl(const float *input, float *res,
+                                  unsigned int size, RunLayerContext &context) {
+
+  bool result = false;
+  do {
+    result =
+      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
+                             AdditionLayerCL::kernel_addition);
+    if (!result) {
+      break;
+    }
+
+    size_t dim1_size = sizeof(float) * size;
+    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
+      1, &inOutRes, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = AdditionLayerCL::kernel_addition.SetKernelArguments(2, &size,
+                                                                 sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+    result = context.command_queue_inst_.DispatchCommand(
+      AdditionLayerCL::kernel_addition, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
+                                             unsigned int from, unsigned int to,
+                                             bool training) {
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+  TensorDim hidden_dim = hidden_.getDim();
+  TensorDim hidden_step_dim = hidden_dim;
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  hidden_step_dim.batch(1);
+  hidden_step_dim.height(to - from);
+
+  for (unsigned int b = 0; b < hidden_.batch(); ++b) {
+    Tensor hidden_step = hidden_.getSharedDataTensor(
+      hidden_step_dim, b * hidden_dim.getFeatureLen(), true);
+
+    /** @todo check possibility for in-place of addition layer */
+    for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+      const Tensor &input_ = context.getInput(idx);
+      TensorDim input_dim = input_.getDim();
+
+      TensorDim input_step_dim = input_dim;
+      input_step_dim.batch(1);
+      input_step_dim.height(to - from);
+
+      Tensor input_step = input_.getSharedDataTensor(
+        input_step_dim, b * input_dim.getFeatureLen(), true);
+      if (!idx) {
+        hidden_step.copy(input_step);
+      } else {
+        // hidden_step.add_i(input_step);
+        AddProcess(input_step, hidden_step, context);
+      }
+    }
+  }
+}
+
+void AdditionLayerCL::calcDerivative(RunLayerContext &context) {
+
+  for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+    /**
+     * TODO: replace this with tensor assignment during optimization.
+     * Tensor assignment needs to make sure that the previous connected layers
+     * are not inplace
+     */
+    context.getOutgoingDerivative(idx).copy(
+      context.getIncomingDerivative(SINGLE_INOUT_IDX));
+  }
+}
+
+void AdditionLayerCL::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, add_props);
+  if (!remain_props.empty()) {
+    std::string msg = "[AdditionLayer] Unknown Layer Properties count " +
+                      std::to_string(values.size());
+    throw exception::not_supported(msg);
+  }
+}
+} /* namespace nntrainer */
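
A note on the host wrapper above: addition_cl obtains the kernel via context.clCreateKernel, copies both operands into device buffers, binds the three kernel arguments, dispatches one work-item per element, and reads the accumulated result back into res. Numerically it matches this CPU reference loop (sketch for comparison, not part of the patch):

    // Reference semantics of the addition_cl kernel: in-place accumulation.
    void addition_ref(const float *input, float *res, unsigned int size) {
      for (unsigned int i = 0; i < size; ++i)
        res[i] += input[i]; // same as output[idx] = output[idx] + input[idx]
    }
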
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.h b/nntrainer/layers/cl_layers/addition_layer_cl.h
new file mode 100644
index 0000000000..78b9293351
--- /dev/null
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.h
@@ -0,0 +1,136 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh
+ *
+ * @file addition_layer_cl.h
+ * @date 17 May 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is Addition Layer Class for Neural Network with OpenCL
+ * implementation
+ */
+
+#ifndef __ADDITION_LAYER_CL_H__
+#define __ADDITION_LAYER_CL_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_devel.h>
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+
+#define CREATE_IF_EMPTY_DIMS(tensor, ...) \
+  do {                                    \
+    if (tensor.empty())                   \
+      tensor = Tensor(__VA_ARGS__);       \
+  } while (0)
+
+namespace nntrainer {
+
+/**
+ * @class   AdditionLayerCL
+ * @brief   Addition Layer
+ */
+class AdditionLayerCL : public Layer {
+public:
+  /**
+   * @brief Constructor of Addition Layer
+   */
+  AdditionLayerCL() : Layer(), add_props(props::Print()) {}
+
+  /**
+   * @brief Destructor of Addition Layer
+   */
+  ~AdditionLayerCL(){};
+
+  /**
+   * @brief Move constructor of AdditionLayer.
+   * @param[in] AdditionLayer &&
+   */
+  AdditionLayerCL(AdditionLayerCL &&rhs) noexcept = default;
+
+  /**
+   * @brief Move assignment operator.
+   * @param[in] rhs AdditionLayer to be moved.
+   */
+  AdditionLayerCL &operator=(AdditionLayerCL &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @brief declaring static kernel objects
+   */
+  static opencl::Kernel kernel_addition;
+
+  /**
+   * @brief Process data and dimensions for add operation used in addition layer
+   * @param[in] input Tensor
+   * @param[in] result Tensor
+   * @param[in] RunLayerContext reference
+   */
+  void AddProcess(Tensor const &input, Tensor &result,
+                  RunLayerContext &context);
+
+  /**
+   * @brief     addition : sum of all input vectors
+   * @param[in] input float * for input
+   * @param[in] res float * for result/output
+   * @param[in] size number of elements in input vector
+   * @param[in] context RunLayerContext reference
+   */
+  void addition_cl(const float *input, float *res, unsigned int size,
+                   RunLayerContext &context);
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override {}
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return AdditionLayerCL::type; };
+
+  std::tuple<props::Print>
+    add_props; /**< addition layer properties : print option */
+
+  inline static const std::string type = "addition";
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __ADDITION_LAYER_CL_H__ */
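
As used in AddProcess, CREATE_IF_EMPTY_DIMS(result, result.getDim()) expands to a guarded lazy allocation, i.e.:

    // Expansion of CREATE_IF_EMPTY_DIMS(result, result.getDim()):
    do {
      if (result.empty())
        result = Tensor(result.getDim());
    } while (0)
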
diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build
index 2f1ba7fc03..349e1f443d 100644
--- a/nntrainer/layers/cl_layers/meson.build
+++ b/nntrainer/layers/cl_layers/meson.build
@@ -1,6 +1,7 @@
 cl_layer_sources = [
   'fc_layer_cl.cpp',
-  'blas_kernels.cpp'
+  'blas_kernels.cpp',
+  'addition_layer_cl.cpp'
 ]
 
 foreach s : cl_layer_sources
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index 92c69f7a67..d3d2851d6a 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -656,6 +656,8 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "dot_cl";
   case LayerKernel::SGEMM:
     return "sgemm_cl";
+  case LayerKernel::ADD:
+    return "addition_cl";
   default:
     return "";
   }
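
The literal returned for LayerKernel::ADD has to stay in sync with the __kernel identifier in the program source, since kernel creation presumably resolves the function by this name; a mismatch would surface as a kernel-creation failure at runtime. Side by side, for illustration:

    // layer_context.cpp:    case LayerKernel::ADD: return "addition_cl";
    // kernel source string: __kernel void addition_cl(__global const float* input, ...)
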
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index a3e2a68a8c..4b3c12736f 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -832,7 +832,8 @@ class RunLayerContext {
   enum LayerKernel {
     SGEMV = 1, /**< placeholder for kernel name */
     DOT = 2,   /**< placeholder for kernel name */
-    SGEMM = 4  /**< placeholder for kernel name */
+    SGEMM = 4, /**< placeholder for kernel name */
+    ADD = 8    /**< placeholder for kernel name */
   };
 
   /**
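
The LayerKernel values are distinct powers of two (1, 2, 4, 8), so ADD = 8 keeps the enum usable as a set of bitmask flags as well as plain identifiers. Hypothetical usage, not in this patch:

    // Track which kernels a layer has touched as a bitmask:
    unsigned int used = RunLayerContext::LayerKernel::SGEMM |
                        RunLayerContext::LayerKernel::ADD; // 4 | 8 == 12
    bool uses_add = (used & RunLayerContext::LayerKernel::ADD) != 0;
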
diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py
index 48e68acaf1..6ab31ebd33 100644
--- a/test/input_gen/gen_layer_tests.py
+++ b/test/input_gen/gen_layer_tests.py
@@ -866,3 +866,12 @@ def call(self, inputs):
 
     added = K.layers.Add()
     record_single_fp16(added, [(2, 3, 3, 3), (2, 3, 3, 3)], "added_w16a16")
+
+    added = K.layers.Add()
+    record_single(added, [(2, 3, 3, 3), (2, 3, 3, 3)], "added_w32a32")
+
+    added = K.layers.Add()
+    record_single(added, [(3, 4, 3, 4), (3, 4, 3, 4)], "added_w32a32_2")
+
+    added = K.layers.Add()
+    record_single(added, [(20, 55, 50, 55), (20, 55, 50, 55)], "added_w32a32_3")
diff --git a/test/jni/Android.mk b/test/jni/Android.mk
index 978e98bd67..963beb3b01 100644
--- a/test/jni/Android.mk
+++ b/test/jni/Android.mk
@@ -453,6 +453,7 @@ LOCAL_SRC_FILES := \
     ../unittest/layers/unittest_layers_flatten.cpp \
     ../unittest/layers/unittest_layers_activation.cpp \
     ../unittest/layers/unittest_layers_addition.cpp \
+    ../unittest/layers/unittest_layers_addition_cl.cpp \
    ../unittest/layers/unittest_layers_multiout.cpp \
    ../unittest/layers/unittest_layers_rnn.cpp \
    ../unittest/layers/unittest_layers_rnncell.cpp \
diff --git a/test/unittest/layers/unittest_layers_addition_cl.cpp b/test/unittest/layers/unittest_layers_addition_cl.cpp
new file mode 100644
index 0000000000..a5d6907582
--- /dev/null
+++ b/test/unittest/layers/unittest_layers_addition_cl.cpp
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh
+ *
+ * @file unittest_layers_addition_cl.cpp
+ * @date 17 May 2024
+ * @brief Addition Layer Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <addition_layer_cl.h>
+#include <layers_common_tests.h>
+
+auto semantic_addition_gpu = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>,
+  nntrainer::AdditionLayerCL::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+auto semantic_addition_multi_gpu = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>,
+  nntrainer::AdditionLayerCL::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);
+
+GTEST_PARAMETER_TEST(AdditionGPU, LayerSemantics,
+                     ::testing::Values(semantic_addition_gpu,
+                                       semantic_addition_multi_gpu));
+
+auto addition_w32a32 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {}, "2:3:3:3,2:3:3:3",
+  "added_w32a32.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
+
+auto addition_w32a32_2 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {}, "3:4:3:4,3:4:3:4",
+  "added_w32a32_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
+
+auto addition_w32a32_3 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {},
+  "20:55:50:55,20:55:50:55", "added_w32a32_3.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+
+GTEST_PARAMETER_TEST(AdditionGPU, LayerGoldenTest,
+                     ::testing::Values(addition_w32a32, addition_w32a32_2,
+                                       addition_w32a32_3));

From 11c096b2b08af875a56ad917af8cd87e2c810c5b Mon Sep 17 00:00:00 2001
From: "yash.singh" <yash.singh@samsung.com>
Date: Tue, 28 May 2024 12:31:53 +0530
Subject: [PATCH 2/2] [GPU/OpenCL] Addition Kernel added in reusable blas OpenCL kernels

Added addition kernel to enhance reusability of the common blas kernels.
Used AdditionLayer interface for both CPU and GPU calls.

Signed-off-by: yash.singh <yash.singh@samsung.com>
---
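
With this change the ccapi helper is engine-agnostic and the backend is chosen at creation time; illustrative application snippet (not part of the patch; the GPU path assumes an ENABLE_OPENCL build):

    // Same helper for both engines:
    auto add_cpu = ml::train::layer::Addition({"name=add_cpu"});
    auto add_gpu = ml::train::layer::Addition(
      {"name=add_gpu"}, ml::train::LayerComputeEngine::GPU);
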
 api/ccapi/include/layer.h                      | 16 +---
 .../layers/cl_layers/addition_layer_cl.cpp     | 81 +------------------
 .../layers/cl_layers/addition_layer_cl.h       | 19 +----
 nntrainer/layers/cl_layers/blas_kernels.cpp    | 69 ++++++++++++++++
 nntrainer/layers/cl_layers/blas_kernels.h      | 11 +++
 5 files changed, 86 insertions(+), 110 deletions(-)

diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h
index 7e76134c5b..7fcf1b06d6 100644
--- a/api/ccapi/include/layer.h
+++ b/api/ccapi/include/layer.h
@@ -354,21 +354,11 @@ Reshape(const std::vector<std::string> &properties = {}) {
 /**
  * @brief Helper function to create addition layer
  */
-inline std::unique_ptr<Layer>
-Addition(const std::vector<std::string> &properties = {}) {
-  return createLayer(LayerType::LAYER_ADDITION, properties);
-}
-
-#ifdef ENABLE_OPENCL
-/**
- * @brief Helper function to create Addition layer for GPU
- */
-inline std::unique_ptr<Layer>
-AdditionCL(const std::vector<std::string> &properties = {},
-           const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
+inline std::unique_ptr<Layer> Addition(
+  const std::vector<std::string> &properties = {},
+  const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
   return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
 }
-#endif
 
 /**
  * @brief Helper function to create concat layer
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
index 48ea84d191..1cd9f1de41 100644
--- a/nntrainer/layers/cl_layers/addition_layer_cl.cpp
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
@@ -3,7 +3,7 @@
  * Copyright (C) 2024 Yash Singh
  *
  * @file addition_layer_cl.cpp
- * @date 17 May 2024
+ * @date 28 May 2024
  * @see https://github.com/nnstreamer/nntrainer
  * @author Yash Singh <yash.singh@samsung.com>
  * @bug No known bugs except for NYI items
@@ -11,6 +11,7 @@
  * implementation
  */
 
+#include <blas_kernels.h>
 #include <addition_layer_cl.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <node_exporter.h>
@@ -19,15 +20,6 @@
 
 #include <layer_context.h>
 
-std::string addition_cl_kernel_ =
-  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
-  #pragma printf_support
-  size_t idx = get_global_id(0);
-  if (idx < size) {
-    output[idx] = output[idx] + input[idx];
-  }
-})";
-
 namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
@@ -45,18 +37,11 @@ void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) {
     if (!idx) {
       hidden_.copy(input_);
     } else {
-      // hidden_.add_i(input_);
       AddProcess(input_, hidden_, context);
     }
   }
 }
 
-/**
- * @brief declaring static kernel objects
- *
- */
-opencl::Kernel AdditionLayerCL::kernel_addition;
-
 void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
                                  RunLayerContext &context) {
 
@@ -83,67 +68,6 @@ void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
     throw std::invalid_argument("Error: OpenCL fp16 is not supported yet.");
 }
 
-void AdditionLayerCL::addition_cl(const float *input, float *res,
-                                  unsigned int size, RunLayerContext &context) {
-
-  bool result = false;
-  do {
-    result =
-      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
-                             AdditionLayerCL::kernel_addition);
-    if (!result) {
-      break;
-    }
-
-    size_t dim1_size = sizeof(float) * size;
-    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
-
-    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
-
-    result = inputA.WriteData(context.command_queue_inst_, input);
-    if (!result) {
-      break;
-    }
-
-    result = inOutRes.WriteData(context.command_queue_inst_, res);
-    if (!result) {
-      break;
-    }
-
-    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
-      0, &inputA, sizeof(cl_mem));
-    if (!result) {
-      break;
-    }
-
-    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
-      1, &inOutRes, sizeof(cl_mem));
-    if (!result) {
-      break;
-    }
-
-    result = AdditionLayerCL::kernel_addition.SetKernelArguments(2, &size,
-                                                                 sizeof(int));
-    if (!result) {
-      break;
-    }
-
-    const int work_groups_count[3] = {(int)size, 1, 1};
-    const int work_group_size[3] = {32, 32, 1}; // test-value
-    result = context.command_queue_inst_.DispatchCommand(
-      AdditionLayerCL::kernel_addition, work_groups_count, work_group_size);
-    if (!result) {
-      break;
-    }
-
-    result = inOutRes.ReadData(context.command_queue_inst_, res);
-    if (!result) {
-      break;
-    }
-
-  } while (false);
-}
-
 void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
                                              unsigned int from, unsigned int to,
                                              bool training) {
@@ -179,7 +103,6 @@ void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
       if (!idx) {
         hidden_step.copy(input_step);
       } else {
-        // hidden_step.add_i(input_step);
         AddProcess(input_step, hidden_step, context);
       }
     }
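
The layer now delegates all device work to the shared kernel, and any other OpenCL layer can reuse the same element-wise accumulation through blas_kernels.h instead of declaring its own kernel. Sketch of a hypothetical consumer (names invented for illustration):

    // Inside some other CL layer's forwarding(): accumulate a partial
    // result into the output tensor on the GPU.
    nntrainer::addition_cl(partial.getData(), out.getData(), out.size(),
                           context);
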
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.h b/nntrainer/layers/cl_layers/addition_layer_cl.h
index 78b9293351..b556746a7c 100644
--- a/nntrainer/layers/cl_layers/addition_layer_cl.h
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.h
@@ -3,7 +3,7 @@
  * Copyright (C) 2024 Yash Singh
 *
  * @file addition_layer_cl.h
- * @date 17 May 2024
+ * @date 28 May 2024
  * @see https://github.com/nnstreamer/nntrainer
  * @author Yash Singh <yash.singh@samsung.com>
  * @bug No known bugs except for NYI items
@@ -17,8 +17,6 @@
 
 #include <common_properties.h>
 #include <layer_devel.h>
-#include <opencl_buffer.h>
-#include <opencl_kernel.h>
 
 #define CREATE_IF_EMPTY_DIMS(tensor, ...) \
   do {                                    \
     if (tensor.empty())                   \
       tensor = Tensor(__VA_ARGS__);       \
   } while (0)
 
 namespace nntrainer {
@@ -78,11 +76,6 @@ class AdditionLayerCL : public Layer {
    */
   void calcDerivative(RunLayerContext &context) override;
 
-  /**
-   * @brief declaring static kernel objects
-   */
-  static opencl::Kernel kernel_addition;
-
   /**
    * @brief Process data and dimensions for add operation used in addition layer
    * @param[in] input Tensor
@@ -92,16 +85,6 @@ class AdditionLayerCL : public Layer {
   void AddProcess(Tensor const &input, Tensor &result,
                   RunLayerContext &context);
 
-  /**
-   * @brief     addition : sum of all input vectors
-   * @param[in] input float * for input
-   * @param[in] res float * for result/output
-   * @param[in] size number of elements in input vector
-   * @param[in] context RunLayerContext reference
-   */
-  void addition_cl(const float *input, float *res, unsigned int size,
-                   RunLayerContext &context);
-
   /**
    * @copydoc bool supportBackwarding() const
    */
diff --git a/nntrainer/layers/cl_layers/blas_kernels.cpp b/nntrainer/layers/cl_layers/blas_kernels.cpp
index c190688c66..4f85189b8d 100644
--- a/nntrainer/layers/cl_layers/blas_kernels.cpp
+++ b/nntrainer/layers/cl_layers/blas_kernels.cpp
@@ -51,12 +51,22 @@ std::string sgemm_cl_kernel_ =
     C[m * ldc + n] = c;
   })";
 
+std::string addition_cl_kernel_ =
+  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
+  #pragma printf_support
+  size_t idx = get_global_id(0);
+  if (idx < size) {
+    output[idx] = output[idx] + input[idx];
+  }
+  })";
+
 /**
  * @brief declaring global kernel objects
  */
 opencl::Kernel kernel_sgemv;
 opencl::Kernel kernel_sgemm;
 opencl::Kernel kernel_dot;
+opencl::Kernel kernel_addition;
 
 void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
               unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -298,4 +308,63 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
 
   } while (false);
 }
+
+void addition_cl(const float *input, float *res, unsigned int size,
+                 RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result =
+      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
+                             kernel_addition);
+    if (!result) {
+      break;
+    }
+
+    size_t dim1_size = sizeof(float) * size;
+    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition.SetKernelArguments(2, &size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+    result = context.command_queue_inst_.DispatchCommand(
+      kernel_addition, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
 } // namespace nntrainer
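
One observation on the dispatch sizing carried over from patch 1: the local size {32, 32, 1} (marked test-value) describes a 2-D work-group for what is a 1-D kernel, and the global size {size, 1, 1} is not rounded up to a multiple of the local size, which strict OpenCL implementations reject. Since the kernel already guards with if (idx < size), a conventional 1-D sizing could look like this (illustrative sketch only, not part of the patch):

    const int local_size = 64; // 1-D work-group; a multiple of typical warp sizes
    const int global_size = ((size + local_size - 1) / local_size) * local_size;
    const int work_groups_count[3] = {global_size, 1, 1};
    const int work_group_size[3] = {local_size, 1, 1};
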
diff --git a/nntrainer/layers/cl_layers/blas_kernels.h b/nntrainer/layers/cl_layers/blas_kernels.h
index ad59b8bbd1..558a3e857e 100644
--- a/nntrainer/layers/cl_layers/blas_kernels.h
+++ b/nntrainer/layers/cl_layers/blas_kernels.h
@@ -27,6 +27,7 @@ namespace nntrainer {
 extern opencl::Kernel kernel_sgemv;
 extern opencl::Kernel kernel_sgemm;
 extern opencl::Kernel kernel_dot;
+extern opencl::Kernel kernel_addition;
 
 /**
  * @brief sgemv computation : Y = A*X + Y
@@ -70,5 +71,15 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
               unsigned int N, unsigned int K, unsigned int lda,
               unsigned int ldb, unsigned int ldc, RunLayerContext &context);
 
+/**
+ * @brief     addition : sum of all input vectors
+ * @param[in] input float * for input
+ * @param[in] res float * for result/output
+ * @param[in] size number of elements in input vector
+ * @param[in] context RunLayerContext reference
+ */
+void addition_cl(const float *input, float *res, unsigned int size,
+                 RunLayerContext &context);
+
 } // namespace nntrainer
 #endif /* __BLAS_KERNELS_H__ */
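
For scale reference, the three golden cases added in patch 1 cover 2*3*3*3 = 54, 3*4*3*4 = 144, and 20*55*50*55 = 3,025,000 elements per input, so the largest case exercises the GPU path across many work-groups rather than a single small dispatch.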