diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h
index 81afe86ee2..fd2020243b 100644
--- a/api/ccapi/include/layer.h
+++ b/api/ccapi/include/layer.h
@@ -358,8 +358,9 @@ Flatten(const std::vector<std::string> &properties = {}) {
  * @brief Helper function to create reshape layer
  */
 inline std::unique_ptr<Layer>
-Reshape(const std::vector<std::string> &properties = {}) {
-  return createLayer(LayerType::LAYER_RESHAPE, properties);
+Reshape(const std::vector<std::string> &properties = {},
+        const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
+  return createLayer(LayerType::LAYER_RESHAPE, properties, compute_engine);
 }
 
 /**
diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp
index 438031d586..1c9a32779a 100644
--- a/nntrainer/cl_context.cpp
+++ b/nntrainer/cl_context.cpp
@@ -6,7 +6,7 @@
  * @date 23 Feb 2024
  * @see https://github.com/nnstreamer/nntrainer
  * @author Debadri Samaddar
- * @author Niket Agarwal
+ * @author Niket Agarwal
  * @bug No known bugs except for NYI items
  * @brief This file contains app context related functions and classes that
  * manages the global configuration of the current OpenCL environment. It also
@@ -16,6 +16,7 @@
 #include <addition_layer_cl.h>
 #include <cl_context.h>
 #include <fc_layer_cl.h>
+#include <reshape_cl.h>
 #include <swiglu_cl.h>
 
 namespace nntrainer {
@@ -36,6 +37,9 @@ static void add_default_object(ClContext &cc) {
 
   cc.registerFactory(nntrainer::createLayer<SwiGLULayerCl>, SwiGLULayerCl::type,
                      ml::train::LayerType::LAYER_SWIGLU);
+
+  cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
+                     ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);
 }
 
 static void registerer(ClContext &cc) noexcept {
diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build
index 68622d1c23..aa30060a50 100644
--- a/nntrainer/layers/cl_layers/meson.build
+++ b/nntrainer/layers/cl_layers/meson.build
@@ -2,6 +2,7 @@ cl_layer_sources = [
   'fc_layer_cl.cpp',
   'addition_layer_cl.cpp',
   'swiglu_cl.cpp',
+  'reshape_cl.cpp',
 ]
 
 foreach s : cl_layer_sources
diff --git a/nntrainer/layers/cl_layers/reshape_cl.cpp b/nntrainer/layers/cl_layers/reshape_cl.cpp
new file mode 100644
index 0000000000..2ead207a7f
--- /dev/null
+++ b/nntrainer/layers/cl_layers/reshape_cl.cpp
@@ -0,0 +1,326 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Niket Agarwal
+ *
+ * @file reshape_cl.cpp
+ * @date 18 June 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Niket Agarwal
+ * @bug No known bugs except for NYI items
+ * @brief This is Reshape GPU Layer Implementation
+ */
+
+#include <iostream>
+#include <layer_context.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <reshape_cl.h>
+
+std::string reshape_cl_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+    __kernel void reshape_cl_fp16(__global const half* input,
+                                  __global half* output,
+                                  const int batchsize,
+                                  const int channels,
+                                  const int height,
+                                  const int width) {
+
+      int i = get_global_id(0);
+      output[i] = input[i];
+
+})";
+
+std::string reshape_cl_kernel_ =
+  R"(__kernel void reshape_cl(__global const float* input,
+                              __global float* output,
+                              const int batchsize,
+                              const int channels,
+                              const int height,
+                              const int width) {
+
+      int i = get_global_id(0);
+      output[i] = input[i];
+
+})";
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+void ReshapeLayerCl::finalize(InitLayerContext &context) {
+  NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
+    << "Reshape only supports 1 input for now";
+
+  const TensorDim &in_dim = context.getInputDimensions()[0];
+
+  auto &target_shape = std::get<props::TargetShape>(reshape_props);
+  NNTR_THROW_IF(target_shape.empty(), std::invalid_argument)
+    << "Reshape layer must be provided with target shape";
+  TensorDim out_dim = target_shape.get();
+
+  if ((int)out_dim.getDataLen() == -1) {
+    out_dim.height(1);
+    out_dim.channel(1);
+    out_dim.width(in_dim.getFeatureLen());
+  } else if (out_dim.getFeatureLen() != in_dim.getFeatureLen()) {
+    throw std::invalid_argument(
+      "Target and input size mismatch for reshape layer");
+  }
+
+  out_dim.batch(in_dim.batch());
+
+  context.setOutputDimensions({out_dim});
+}
+
+void ReshapeLayerCl::forwarding(RunLayerContext &context, bool training) {
+  if (!context.executeInPlace()) {
+    Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+    const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+    ReshapeProcess(input, output, context);
+  }
+}
+
+void ReshapeLayerCl::incremental_forwarding(RunLayerContext &context,
+                                            unsigned int from, unsigned int to,
+                                            bool training) {
+  if (!context.executeInPlace()) {
+    Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+    const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+    if (from) {
+      NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+        << "incremental step size is not 1";
+      from = 0;
+      to = 1;
+    }
+    ReshapeProcess(input, output, context);
+  }
+}
+
+opencl::Kernel ReshapeLayerCl::kernel_reshape;
+opencl::Kernel ReshapeLayerCl::kernel_reshape_fp16;
+
+void ReshapeLayerCl::ReshapeProcess(Tensor const &input, Tensor &output,
+                                    RunLayerContext &context) {
+
+  unsigned int input_batch_size, input_height, input_width, input_channels;
+
+  input_batch_size = input.batch();
+  input_height = input.height();
+  input_width = input.width();
+  input_channels = input.channel();
+
+  if (input.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    const float *data = input.getData();
+    float *rdata = output.getData();
+    reshape_cl(data, rdata, input_batch_size, input_channels, input_height,
+               input_width, context);
+  } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    const _FP16 *data = input.getData<_FP16>();
+    _FP16 *rdata = output.getData<_FP16>();
+    reshape_cl_fp16(data, rdata, input_batch_size, input_channels,
+                    input_height, input_width, context);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+}
+
+void ReshapeLayerCl::reshape_cl(const float *input, float *res,
+                                unsigned int input_batch_size,
+                                unsigned int input_channels,
+                                unsigned int input_height,
+                                unsigned int input_width,
+                                RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result =
+      context.clCreateKernel(reshape_cl_kernel_, context.LayerKernel::RESHAPE,
+                             ReshapeLayerCl::kernel_reshape);
+    if (!result) {
+      break;
+    }
+
+    size_t dim_size = sizeof(float) * input_batch_size * input_height *
+                      input_width * input_channels;
+
+    opencl::Buffer inputA(context.context_inst_, dim_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape.SetKernelArguments(0, &inputA,
+                                                               sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape.SetKernelArguments(1, &inOutRes,
+                                                               sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape.SetKernelArguments(
+      2, &input_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape.SetKernelArguments(
+      3, &input_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape.SetKernelArguments(4, &input_height,
+                                                               sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape.SetKernelArguments(5, &input_width,
+                                                               sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)dim_size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ReshapeLayerCl::kernel_reshape, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ReshapeLayerCl::reshape_cl_fp16(const __fp16 *input, __fp16 *res,
+                                     unsigned int input_batch_size,
+                                     unsigned int input_channels,
+                                     unsigned int input_height,
+                                     unsigned int input_width,
+                                     RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(reshape_cl_kernel_fp16_,
+                                    context.LayerKernel::RESHAPE_FP16,
+                                    ReshapeLayerCl::kernel_reshape_fp16);
+    if (!result) {
+      break;
+    }
+
+    size_t dim_size = sizeof(__fp16) * input_batch_size * input_height *
+                      input_width * input_channels;
+
+    opencl::Buffer inputA(context.context_inst_, dim_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape_fp16.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape_fp16.SetKernelArguments(
+      1, &inOutRes, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape_fp16.SetKernelArguments(
+      2, &input_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape_fp16.SetKernelArguments(
+      3, &input_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape_fp16.SetKernelArguments(
+      4, &input_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ReshapeLayerCl::kernel_reshape_fp16.SetKernelArguments(
+      5, &input_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)dim_size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ReshapeLayerCl::kernel_reshape_fp16, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ReshapeLayerCl::calcDerivative(RunLayerContext &context) {
+  if (!context.executeInPlace()) {
+    context.getOutgoingDerivative(SINGLE_INOUT_IDX)
+      .copyData(context.getIncomingDerivative(SINGLE_INOUT_IDX));
+  }
+}
+
+void ReshapeLayerCl::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, reshape_props);
+  if (!remain_props.empty()) {
+    std::string msg = "[ReshapeLayer] Unknown Layer Properties count " +
+                      std::to_string(remain_props.size());
+    throw exception::not_supported(msg);
+  }
+}
+
+void ReshapeLayerCl::exportTo(Exporter &exporter,
+                              const ml::train::ExportMethods &method) const {
+  exporter.saveResult(reshape_props, method, this);
+}
+
+} /* namespace nntrainer */
diff --git a/nntrainer/layers/cl_layers/reshape_cl.h b/nntrainer/layers/cl_layers/reshape_cl.h
new file mode 100644
index 0000000000..8a6644978a
--- /dev/null
+++ b/nntrainer/layers/cl_layers/reshape_cl.h
@@ -0,0 +1,154 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Niket Agarwal
+ *
+ * @file reshape_cl.h
+ * @date 18 June 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Niket Agarwal
+ * @bug No known bugs except for NYI items
+ * @brief This is Reshape GPU Layer Implementation
+ *
+ */
+
+#ifndef __RESHAPE_LAYER_CL_H__
+#define __RESHAPE_LAYER_CL_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_devel.h>
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+
+namespace nntrainer {
+/**
+ * @class Reshape Layer
+ * @brief Reshape Layer
+ */
+class ReshapeLayerCl : public Layer {
+public:
+  /**
+   * @brief Constructor of Reshape Layer
+   */
+  ReshapeLayerCl() : Layer() {}
+
+  /**
+   * @brief Destructor of Reshape Layer
+   */
+  ~ReshapeLayerCl() = default;
+
+  /**
+   * @brief Move constructor of ReshapeLayer.
+   * @param[in] ReshapeLayer &&
+   */
+  ReshapeLayerCl(ReshapeLayerCl &&rhs) noexcept = default;
+
+  /**
+   * @brief Move assignment operator.
+   * @param[in] rhs ReshapeLayer to be moved.
+   */
+  ReshapeLayerCl &operator=(ReshapeLayerCl &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return false; };
+
+  /**
+   * @copydoc Layer::supportInPlace()
+   */
+  bool supportInPlace() const override { return true; }
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return ReshapeLayerCl::type; };
+
+  inline static const std::string type = "reshape";
+
+  static opencl::Kernel kernel_reshape;
+  static opencl::Kernel kernel_reshape_fp16;
+
+  /**
+   * @brief Process data and dimensions for reshape operation
+   * @param[in] input Tensor
+   * @param[in] result Tensor
+   * @param[in] RunLayerContext reference
+   */
+  void ReshapeProcess(Tensor const &input, Tensor &result,
+                      RunLayerContext &context);
+
+  /**
+   * @brief reshape computation
+   * @param[in] input float * for Input Tensor
+   * @param[in] res float * for Output Tensor
+   * @param[in] input_batch_size represents the number of samples in the input
+   * tensor
+   * @param[in] input_channels represents the channels of the input tensor
+   * @param[in] input_height represents the height of the input tensor
+   * @param[in] input_width represents the width of the input tensor
+   * @param[in] context RunLayerContext reference
+   */
+  void reshape_cl(const float *input, float *res,
+                  unsigned int input_batch_size, unsigned int input_channels,
+                  unsigned int input_height, unsigned int input_width,
+                  RunLayerContext &context);
+
+  /**
+   * @brief reshape computation
+   * @param[in] input fp16 * for Input Tensor
+   * @param[in] res fp16 * for Output Tensor
+   * @param[in] input_batch_size represents the number of samples in the input
+   * tensor
+   * @param[in] input_channels represents the channels of the input tensor
+   * @param[in] input_height represents the height of the input tensor
+   * @param[in] input_width represents the width of the input tensor
+   * @param[in] context RunLayerContext reference
+   */
+  void reshape_cl_fp16(const __fp16 *input, __fp16 *res,
+                       unsigned int input_batch_size,
+                       unsigned int input_channels, unsigned int input_height,
+                       unsigned int input_width, RunLayerContext &context);
+
+protected:
+  std::tuple<props::TargetShape>
+    reshape_props; /**< reshape properties : target_shape after reshape */
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __RESHAPE_LAYER_CL_H__ */
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index b959a0af20..93267adc7b 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -699,6 +699,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "swiglu_cl";
   case LayerKernel::SWIGLU_FP16:
     return "swiglu_cl_fp16";
+  case LayerKernel::RESHAPE:
+    return "reshape_cl";
+  case LayerKernel::RESHAPE_FP16:
+    return "reshape_cl_fp16";
   default:
     return "";
   }
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index fc0ee91f49..57d98dc5bd 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -830,16 +830,18 @@ class RunLayerContext {
    * getKernelName function.
    */
   enum LayerKernel {
-    SGEMV = 1 << 0,      /**< placeholder for kernel name */
-    DOT = 1 << 1,        /**< placeholder for kernel name */
-    SGEMM = 1 << 2,      /**< placeholder for kernel name */
-    SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */
-    DOT_FP16 = 1 << 4,   /**< placeholder for kernel name */
-    SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */
-    ADD = 1 << 6,        /**< placeholder for kernel name */
-    ADD_FP16 = 1 << 7,   /**< placeholder for kernel name */
-    SWIGLU = 1 << 8,     /**< placeholder for kernel name */
-    SWIGLU_FP16 = 1 << 9 /**< placeholder for kernel name */
+    SGEMV = 1 << 0,        /**< placeholder for kernel name */
+    DOT = 1 << 1,          /**< placeholder for kernel name */
+    SGEMM = 1 << 2,        /**< placeholder for kernel name */
+    SGEMV_FP16 = 1 << 3,   /**< placeholder for kernel name */
+    DOT_FP16 = 1 << 4,     /**< placeholder for kernel name */
+    SGEMM_FP16 = 1 << 5,   /**< placeholder for kernel name */
+    ADD = 1 << 6,          /**< placeholder for kernel name */
+    ADD_FP16 = 1 << 7,     /**< placeholder for kernel name */
+    SWIGLU = 1 << 8,       /**< placeholder for kernel name */
+    SWIGLU_FP16 = 1 << 9,  /**< placeholder for kernel name */
+    RESHAPE = 1 << 10,     /**< placeholder for kernel name */
+    RESHAPE_FP16 = 1 << 11 /**< placeholder for kernel name */
   };
 
   /**
diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py
index 99017d071f..7b0a701206 100644
--- a/test/input_gen/gen_layer_tests.py
+++ b/test/input_gen/gen_layer_tests.py
@@ -897,3 +897,13 @@ def swiglu(inputs):
         "swiglu",
         input_type="float",
     )
+
+    def reshape_tensor(tensor, batch_size, input_channel, input_height, input_width):
+        output_height = 1
+        output_width = input_channel * input_height * input_width
+        output_channel = 1
+        output_shape = (batch_size, output_channel, output_height, output_width)
+        return tf.reshape(tensor, output_shape)
+
+    reshape_layer = tf.keras.layers.Lambda(lambda x: reshape_tensor(x, 2, 3, 3, 3))
+    record_single(reshape_layer, (2, 3, 3, 3), "reshape", input_type="float")
diff --git a/test/jni/Android.mk b/test/jni/Android.mk
index 2e947e5289..367ba499ac 100644
--- a/test/jni/Android.mk
+++ b/test/jni/Android.mk
@@ -445,6 +445,7 @@ LOCAL_SRC_FILES := \
         ../unittest/layers/unittest_layers_fully_connected_cl.cpp \
         ../unittest/layers/unittest_layers_input.cpp \
         ../unittest/layers/unittest_layers_loss.cpp \
+        ../unittest/layers/unittest_layers_reshape_cl.cpp \
         ../unittest/layers/unittest_layers_fully_connected.cpp \
         ../unittest/layers/unittest_layers_batch_normalization.cpp \
         ../unittest/layers/unittest_layers_layer_normalization.cpp \
diff --git a/test/unittest/layers/unittest_layers_reshape_cl.cpp b/test/unittest/layers/unittest_layers_reshape_cl.cpp
new file mode 100644
index 0000000000..ae0a779039
--- /dev/null
+++ b/test/unittest/layers/unittest_layers_reshape_cl.cpp
@@ -0,0 +1,36 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Niket Agarwal
+ *
+ * @file unittest_layers_reshape_cl.cpp
+ * @date 18th June 2024
+ * @brief Reshape Layer Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Niket Agarwal
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <reshape_cl.h>
+
+auto semantic_reshape_gpu = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::ReshapeLayerCl>,
+  nntrainer::ReshapeLayerCl::type, {"target_shape=-1"},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+GTEST_PARAMETER_TEST(ReshapeGPU, LayerSemanticsGpu,
+                     ::testing::Values(semantic_reshape_gpu));
+
+auto reshape_basic_plain = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ReshapeLayerCl>, {"target_shape=-1"},
+  "2:3:3:3", "reshape.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
+    LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
+    LayerGoldenTestParamOptions::USE_INC_FORWARD,
+  "nchw", "fp32", "fp32");
+
+GTEST_PARAMETER_TEST(ReshapeGPU, LayerGoldenTest,
+                     ::testing::Values(reshape_basic_plain));
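
For reference, a minimal usage sketch of the updated ccapi helper, showing how an application could route a reshape layer to the new OpenCL backend through the added `compute_engine` argument. The `GPU` enumerator, the `ml::train::layer` namespace, and the property strings follow existing ccapi conventions and are assumptions for illustration, not part of this patch.

```cpp
// Usage sketch (assumption: nntrainer built with OpenCL support so the
// ClContext factory registered in this patch is available at runtime).
#include <layer.h> // ccapi creator helpers in ml::train::layer
#include <memory>

std::unique_ptr<ml::train::Layer> make_gpu_reshape() {
  // target_shape is channel:height:width; the batch dimension is taken from
  // the input, matching ReshapeLayerCl::finalize() above.
  return ml::train::layer::Reshape({"name=reshape_gpu", "target_shape=1:1:27"},
                                   ml::train::LayerComputeEngine::GPU);
}
```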