From 5d2614ecef97a1a4cca1557385322e1e8f351fdb Mon Sep 17 00:00:00 2001 From: Niket Agarwal Date: Wed, 9 Oct 2024 12:00:40 +0530 Subject: [PATCH] [GPU/OpenCL] Initial version of Transpose (all axes) with OpenCL ops Added naive version of OpenCL implementation for Transpose. Incorporated kernel for ops using blas_kernels. Added unit test for Transpose_cl. Signed-off-by: Niket Agarwal --- Applications/LLaMA/jni/transpose_layer.h | 2 +- api/ccapi/include/layer.h | 14 +- api/nntrainer-api-common.h | 7 +- nntrainer/cl_context.cpp | 5 + nntrainer/layers/cl_layers/meson.build | 1 + nntrainer/layers/cl_layers/transpose_cl.cpp | 91 ++++++ nntrainer/layers/cl_layers/transpose_cl.h | 105 +++++++ .../cl_operations/blas_kernel_interface.cpp | 54 ++++ .../cl_operations/blas_kernel_interface.h | 9 + .../cl_operations/blas_kernel_strings.h | 144 ++++++++++ .../tensor/cl_operations/blas_kernels.cpp | 264 ++++++++++++++++++ nntrainer/tensor/cl_operations/blas_kernels.h | 90 ++++++ .../cl_operations/blas_kernels_fp16.cpp | 264 ++++++++++++++++++ test/input_gen/gen_layer_tests.py | 24 ++ test/jni/Android.mk | 1 + .../layers/unittest_layers_transpose_cl.cpp | 83 ++++++ 16 files changed, 1152 insertions(+), 6 deletions(-) create mode 100644 nntrainer/layers/cl_layers/transpose_cl.cpp create mode 100644 nntrainer/layers/cl_layers/transpose_cl.h create mode 100644 test/unittest/layers/unittest_layers_transpose_cl.cpp diff --git a/Applications/LLaMA/jni/transpose_layer.h b/Applications/LLaMA/jni/transpose_layer.h index a2ae7d8d64..bd5c4db5a6 100644 --- a/Applications/LLaMA/jni/transpose_layer.h +++ b/Applications/LLaMA/jni/transpose_layer.h @@ -58,7 +58,7 @@ class TransposeLayer final : public nntrainer::Layer { /** * @copydoc bool supportBackwarding() const */ - bool supportBackwarding() const override { return true; }; + bool supportBackwarding() const override { return false; }; /** * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index d9f9cffdd2..04d7710039 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -102,8 +102,9 @@ enum LayerType { LAYER_LOSS_CONSTANT_DERIVATIVE, /**< Synthetic loss layer to feed constant derivative */ LAYER_UPSAMPLE2D, /**< Upsample 2D Layer type */ - LAYER_RMSNORM = ML_TRAIN_LAYER_TYPE_RMSNORM, /** &properties = {}, return createLayer(LayerType::LAYER_RMSNORM, properties, compute_engine); } +/** + * @brief Helper function to create Transpose layer + */ +inline std::unique_ptr +Transpose(const std::vector &properties = {}, + const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { + return createLayer(LayerType::LAYER_TRANSPOSE, properties, compute_engine); +} + /** * @brief Helper function to create batch normalization layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 6604ed9494..d9c51a360c 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -62,9 +62,10 @@ typedef enum { 27, /**< Layer Normalization Layer type (Since 7.0) */ ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING = 28, /**< Positional Encoding Layer type (Since 7.0) */ - ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ - ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ - ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ + ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ + ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_TRANSPOSE = 32, /**< Transpose Layer type */ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp index 10e3ecdbb7..c02bac016d 100644 --- a/nntrainer/cl_context.cpp +++ b/nntrainer/cl_context.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace nntrainer { @@ -51,6 +52,10 @@ static void add_default_object(ClContext &cc) { cc.registerFactory(nntrainer::createLayer, ConcatLayerCl::type, ml::train::LayerType::LAYER_CONCAT); + + cc.registerFactory(nntrainer::createLayer, + TransposeLayerCl::type, + ml::train::LayerType::LAYER_TRANSPOSE); } static void registerer(ClContext &cc) noexcept { diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index fbfd46961b..7fc19154e0 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -5,6 +5,7 @@ cl_layer_sources = [ 'reshape_cl.cpp', 'rmsnorm_layer_cl.cpp', 'concat_cl.cpp', + 'transpose_cl.cpp', ] foreach s : cl_layer_sources diff --git a/nntrainer/layers/cl_layers/transpose_cl.cpp b/nntrainer/layers/cl_layers/transpose_cl.cpp new file mode 100644 index 0000000000..00ac351a47 --- /dev/null +++ b/nntrainer/layers/cl_layers/transpose_cl.cpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file transpose_cl.cpp + * @date 31 July 2024 + * @brief Implementation of transpose layer + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * + */ + +#include "transpose_cl.h" +#include +#include +#include +#include +#include +#include + +namespace nntrainer { + +static constexpr size_t SINGLE_INOUT_IDX = 0; + +void TransposeLayerCl::finalize(InitLayerContext &context) { + std::vector dim = context.getInputDimensions(); + + for (unsigned int i = 0; i < dim.size(); ++i) { + if (dim[i].getDataLen() == 0) { + throw std::invalid_argument("Input dimension is not set"); + } else { + dim[i].channel(dim[i].channel()); + dim[i].height(dim[i].height()); + dim[i].width(dim[i].width()); + } + } + + context.setOutputDimensions(dim); +} + +void TransposeLayerCl::forwarding(RunLayerContext &context, bool training) { + Tensor &in = context.getInput(SINGLE_INOUT_IDX); + Tensor &out = context.getOutput(SINGLE_INOUT_IDX); + transposeCl("1:0:2", in, out); +} + +void TransposeLayerCl::incremental_forwarding(RunLayerContext &context, + unsigned int from, + unsigned int to, bool training) { + Tensor &in = context.getInput(SINGLE_INOUT_IDX); + Tensor &out = context.getOutput(SINGLE_INOUT_IDX); + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + transposeCl("1:0:2", in, out); +} + +void TransposeLayerCl::calcDerivative(RunLayerContext &context) { + std::throw_with_nested(std::runtime_error("Training is not supported yet.")); +} + +void TransposeLayerCl::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, transpose_props); + if (!remain_props.empty()) { + std::string msg = "[TransposeLayerCl] Unknown Layer Properties count " + + std::to_string(values.size()); + throw exception::not_supported(msg); + } +} + +#ifdef PLUGGABLE + +Layer *create_transpose_layer_cl() { + auto layer = new TransposeLayerCl(); + return layer; +} + +void destroy_transpose_layer_cl(Layer *layer) { delete layer; } + +extern "C" { +LayerPluggable ml_train_layer_pluggable{create_transpose_layer_cl, + destroy_transpose_layer_cl}; +} + +#endif + +} // namespace nntrainer diff --git a/nntrainer/layers/cl_layers/transpose_cl.h b/nntrainer/layers/cl_layers/transpose_cl.h new file mode 100644 index 0000000000..2c6c9efa28 --- /dev/null +++ b/nntrainer/layers/cl_layers/transpose_cl.h @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file transpose_cl.h + * @date 31 July 2024 + * @brief Implementation of transpose layer + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * + */ + +#ifndef __TRANSPOSE_LAYER_CL_H__ +#define __TRANSPOSE_LAYER_CL_H__ + +#include +#include +#include +#include + +#define CREATE_IF_EMPTY_DIMS(tensor, ...) \ + do { \ + if (tensor.empty()) \ + tensor = Tensor(__VA_ARGS__); \ + } while (0); + +namespace nntrainer { + +/** + * @brief A tranpose layer. + * + */ +class TransposeLayerCl final : public Layer { +public: + /** + * @brief Construct a new transpose layer object + * + */ + TransposeLayerCl() : Layer(), transpose_props(props::Print()) {} + + /** + * @brief Destroy the transpose layer object + * + */ + ~TransposeLayerCl() {} + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override{}; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { return TransposeLayerCl::type; }; + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + inline static const std::string type = "transpose"; + + static opencl::Kernel kernel_transpose_axis0; + static opencl::Kernel kernel_transpose_fp16_axis0; + static opencl::Kernel kernel_transpose_axis1; + static opencl::Kernel kernel_transpose_fp16_axis1; + static opencl::Kernel kernel_transpose_axis2; + static opencl::Kernel kernel_transpose_fp16_axis2; + + std::tuple transpose_props; /**< transpose layer properties : + unit - number of output neurons */ +}; +} // namespace nntrainer + +#endif /* __TRANSPOSE_LAYER_CL_H__ */ diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp index 23af3f9799..75034314fc 100644 --- a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp @@ -235,4 +235,58 @@ void add_i_cl(Tensor const &input, Tensor &result) { } } +void transposeCl(const std::string &direction, Tensor const &in, + Tensor &result) { + + unsigned int input_batch_size, input_height, input_width, input_channels; + + input_batch_size = in.batch(); + input_height = in.height(); + input_width = in.width(); + input_channels = in.channel(); + + if (in.getDataType() == ml::train::TensorDim::DataType::FP32) { + const float *data = in.getData(); + float *rdata = result.getData(); + // for transpose about channels and height + if (direction[0] == '1' && direction[2] == '0') { + transpose_cl_axis0(data, rdata, input_batch_size, input_channels, + input_height, input_width); + } + // for transpose about height and width + else if (direction[0] == '0' && direction[2] == '2') { + transpose_cl_axis1(data, rdata, input_batch_size, input_channels, + input_height, input_width); + } + // for transpose about channels and width + else if (direction[0] == '2' && direction[2] == '1') { + transpose_cl_axis2(data, rdata, input_batch_size, input_channels, + input_height, input_width); + } + + } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *data = in.getData<_FP16>(); + _FP16 *rdata = result.getData<_FP16>(); + // for transpose about channels and height + if (direction[0] == '1' && direction[2] == '0') { + transpose_cl_axis0(data, rdata, input_batch_size, input_channels, + input_height, input_width); + } + // for transpose about height and width + else if (direction[0] == '0' && direction[2] == '2') { + transpose_cl_axis1(data, rdata, input_batch_size, input_channels, + input_height, input_width); + } + // for transpose about channels and width + else if (direction[0] == '2' && direction[2] == '1') { + transpose_cl_axis2(data, rdata, input_batch_size, input_channels, + input_height, input_width); + } +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } +} + } // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.h b/nntrainer/tensor/cl_operations/blas_kernel_interface.h index 0b8d29a53c..b5ae4ee40e 100644 --- a/nntrainer/tensor/cl_operations/blas_kernel_interface.h +++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.h @@ -70,5 +70,14 @@ void multiplyCl(Tensor &input, float const &value); */ void add_i_cl(Tensor const &input, Tensor &result); +/** + * @brief Process data and dimensions for transpose operation + * @param[in] direction string + * @param[in] input Tensor + * @param[in] result Tensor + */ +void transposeCl(const std::string &direction, Tensor const &in, + Tensor &result); + } // namespace nntrainer #endif /* __BLAS_KERNEL_INTERFACE_H__ */ diff --git a/nntrainer/tensor/cl_operations/blas_kernel_strings.h b/nntrainer/tensor/cl_operations/blas_kernel_strings.h index 616900b719..7b26e56b90 100644 --- a/nntrainer/tensor/cl_operations/blas_kernel_strings.h +++ b/nntrainer/tensor/cl_operations/blas_kernel_strings.h @@ -121,6 +121,75 @@ static const std::string sscal_cl_kernel_ = X[i] *= alpha; })"; +static const std::string transpose_cl_kernel_axis0 = + R"(__kernel void transpose_cl_axis0(__global const float* in, + __global float* output, + const int batch_size, + const int channels, + const int height, + const int width) { + // Calculate h and w from the global IDs + int h = get_global_id(0); + int w = get_global_id(1); + if (h < height && w < width) { + for (int c = 0; c < channels; ++c) { + for (int n = 0; n < batch_size; ++n) { + // Calculate the input and output indices + int input_index = n * (channels * height * width) + c * (height * width) + h * width + w; + int output_index = n * (channels * height * width) + h * (channels * width) + c * width + w; + // Transpose channel and height, copying data from input to output + output[output_index] = in[input_index]; + } + } + } +})"; + +static const std::string transpose_cl_kernel_axis1 = + R"(__kernel void transpose_cl_axis1(__global const float* in, + __global float* output, + const int batch_size, + const int channels, + const int height, + const int width) { + // Calculate h and w from the global IDs + int h = get_global_id(0); + int w = get_global_id(1); + if (h < height && w < width) { + for (int c = 0; c < channels; ++c) { + for (int n = 0; n < batch_size; ++n) { + // Calculate the input and output indices + int input_index = n * (channels * height * width) + c * (height * width) + h * width + w; + int output_index = n * (channels * height * width) + c * (height * width) + w * height + h; + // Transpose height and width, copying data from input to output + output[output_index] = in[input_index]; + } + } + } +})"; + +static const std::string transpose_cl_kernel_axis2 = + R"(__kernel void transpose_cl_axis2(__global const float* in, + __global float* output, + const int batch_size, + const int channels, + const int height, + const int width) { + // Calculate c and w from the global IDs + int c = get_global_id(0); + int w = get_global_id(1); + if (c < channels && w < width) { + for (int h = 0; h < height; ++h) { + for (int n = 0; n < batch_size; ++n) { + // Calculate the input and output indices + int input_index = n * (channels * height * width) + c * (height * width) + h * width + w; + int output_index = n * (channels * height * width) + w * (height * channels) + h * channels + c; + // Transpose channel and width, copying data from input to output + output[output_index] = in[input_index]; + } + } + } +})"; + #ifdef ENABLE_FP16 static const std::string sgemv_cl_kernel_fp16_ = R"( @@ -244,6 +313,81 @@ static const std::string sscal_cl_kernel_fp16_ = unsigned int i = get_global_id(0); X[i] *= alpha; })"; + +static const std::string transpose_cl_kernel_fp16_axis0 = + R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void transpose_cl_fp16_axis0(__global const half* in, + __global half* output, + const int batch_size, + const int channels, + const int height, + const int width) { + // Calculate h and w from the global IDs + int h = get_global_id(0); + int w = get_global_id(1); + if (h < height && w < width) { + for (int c = 0; c < channels; ++c) { + for (int n = 0; n < batch_size; ++n) { + // Calculate the input and output indices + int input_index = n * (channels * height * width) + c * (height * width) + h * width + w; + int output_index = n * (channels * height * width) + h * (channels * width) + c * width + w; + // Transpose channel and height, copying data from input to output + output[output_index] = in[input_index]; + } + } + } +})"; + +static const std::string transpose_cl_kernel_fp16_axis1 = + R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void transpose_cl_fp16_axis1(__global const half* in, + __global half* output, + const int batch_size, + const int channels, + const int height, + const int width) { + // Calculate h and w from the global IDs + int h = get_global_id(0); + int w = get_global_id(1); + if (h < height && w < width) { + for (int c = 0; c < channels; ++c) { + for (int n = 0; n < batch_size; ++n) { + // Calculate the input and output indices + int input_index = n * (channels * height * width) + c * (height * width) + h * width + w; + int output_index = n * (channels * height * width) + c * (height * width) + w * height + h; + // Transpose height and width, copying data from input to output + output[output_index] = in[input_index]; + } + } + } +})"; + +static const std::string transpose_cl_kernel_fp16_axis2 = + R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void transpose_cl_fp16_axis2(__global const half* in, + __global half* output, + const int batch_size, + const int channels, + const int height, + const int width) { + // Calculate c and w from the global IDs + int c = get_global_id(0); + int w = get_global_id(1); + if (c < channels && w < width) { + for (int h = 0; h < height; ++h) { + for (int n = 0; n < batch_size; ++n) { + // Calculate the input and output indices + int input_index = n * (channels * height * width) + c * (height * width) + h * width + w; + int output_index = n * (channels * height * width) + w * (height * channels) + h * channels + c; + // Transpose channel and width, copying data from input to output + output[output_index] = in[input_index]; + } + } + } +})"; #endif } // namespace nntrainer #endif /* __BLAS_KERNEL_INTERFACE_H__ */ diff --git a/nntrainer/tensor/cl_operations/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp index a8236988ad..8f319f411c 100644 --- a/nntrainer/tensor/cl_operations/blas_kernels.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp @@ -387,4 +387,268 @@ void sscal_cl(float *X, const unsigned int N, const float alpha) { } while (false); } + +void transpose_cl_axis0(const float *in, float *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width) { + + bool result = false; + + do { + ClContext::SharedPtrClKernel kernel_transpose_ptr = + cl_context_ref.registerClKernel(transpose_cl_kernel_axis0, + "transpose_cl_axis0"); + + if (!kernel_transpose_ptr) { + break; + } + + size_t dim_size = sizeof(float) * input_batch_size * input_height * + input_width * input_channels; + + opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); + if (!result) { + break; + } + + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_ptr->SetKernelArguments(2, &input_batch_size, + sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(3, &input_channels, sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(4, &input_height, sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(5, &input_width, sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {(int)input_height, (int)input_width, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_transpose_ptr, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + } while (false); +} + +void transpose_cl_axis1(const float *in, float *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width) { + + bool result = false; + + do { + ClContext::SharedPtrClKernel kernel_transpose_ptr = + cl_context_ref.registerClKernel(transpose_cl_kernel_axis1, + "transpose_cl_axis1"); + + if (!kernel_transpose_ptr) { + break; + } + + size_t dim_size = sizeof(float) * input_batch_size * input_height * + input_width * input_channels; + + opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); + if (!result) { + break; + } + + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_ptr->SetKernelArguments(2, &input_batch_size, + sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(3, &input_channels, sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(4, &input_height, sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(5, &input_width, sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {(int)input_height, (int)input_width, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_transpose_ptr, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + } while (false); +} + +void transpose_cl_axis2(const float *in, float *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width) { + + bool result = false; + + do { + ClContext::SharedPtrClKernel kernel_transpose_ptr = + cl_context_ref.registerClKernel(transpose_cl_kernel_axis2, + "transpose_cl_axis2"); + + if (!kernel_transpose_ptr) { + break; + } + + size_t dim_size = sizeof(float) * input_batch_size * input_height * + input_width * input_channels; + + opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); + if (!result) { + break; + } + + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_ptr->SetKernelArguments(2, &input_batch_size, + sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(3, &input_channels, sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(4, &input_height, sizeof(int)); + if (!result) { + break; + } + + result = + kernel_transpose_ptr->SetKernelArguments(5, &input_width, sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {(int)input_channels, (int)input_width, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_transpose_ptr, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + } while (false); +} } // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h index 247314740a..d7392d2ded 100644 --- a/nntrainer/tensor/cl_operations/blas_kernels.h +++ b/nntrainer/tensor/cl_operations/blas_kernels.h @@ -85,6 +85,51 @@ void addition_cl(const float *input, float *res, unsigned int size); */ void sscal_cl(float *X, const unsigned int N, const float alpha); +/** + * @brief transpose computation about channels and height + * @param[in] input float * for Input Tensor + * @param[in] res float * for Output Tensor + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] input_width represents the width of the input tensor + */ +void transpose_cl_axis0(const float *in, float *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width); + +/** + * @brief transpose computation about height and width + * @param[in] input float * for Input Tensor + * @param[in] res float * for Output Tensor + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] input_width represents the width of the input tensor + */ +void transpose_cl_axis1(const float *in, float *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width); + +/** + * @brief transpose computation about channels and width + * @param[in] input float * for Input Tensor + * @param[in] res float * for Output Tensor + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] input_width represents the width of the input tensor + */ +void transpose_cl_axis2(const float *in, float *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width); + #ifdef ENABLE_FP16 /** @@ -148,6 +193,51 @@ void addition_cl(const __fp16 *input, __fp16 *res, unsigned int size); * @param[in] context RunLayerContext reference */ void sscal_cl(__fp16 *X, const unsigned int N, const float alpha); + +/** + * @brief transpose computation about channels and height + * @param[in] input fp16 * for Input Tensor + * @param[in] res fp16 * for Output Tensor + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] input_width represents the width of the input tensor + */ +void transpose_cl_axis0(const __fp16 *in, __fp16 *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width); + +/** + * @brief transpose computation about height and width + * @param[in] input fp16 * for Input Tensor + * @param[in] res fp16 * for Output Tensor + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] input_width represents the width of the input tensor + */ +void transpose_cl_axis1(const __fp16 *in, __fp16 *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width); + +/** + * @brief transpose computation about channels and width + * @param[in] input fp16 * for Input Tensor + * @param[in] res fp16 * for Output Tensor + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] input_width represents the width of the input tensor + */ +void transpose_cl_axis2(const __fp16 *in, __fp16 *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width); #endif } // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp index 6aa7ccb6e2..583be672a5 100644 --- a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp @@ -402,4 +402,268 @@ void sscal_cl(__fp16 *X, const unsigned int N, const float alpha) { } while (false); } + +void transpose_cl_axis0(const __fp16 *in, __fp16 *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width) { + + bool result = false; + + do { + ClContext::SharedPtrClKernel kernel_transpose_fp_16_ptr = + cl_context_ref.registerClKernel(transpose_cl_kernel_fp16_axis0, + "transpose_cl_fp16_axis0"); + + if (!kernel_transpose_fp_16_ptr) { + break; + } + + size_t dim_size = sizeof(__fp16) * input_batch_size * input_height * + input_width * input_channels; + + opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); + if (!result) { + break; + } + + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(0, &inputA, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(1, &inOutRes, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments( + 2, &input_batch_size, sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(3, &input_channels, + sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(4, &input_height, + sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(5, &input_width, + sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {(int)input_height, (int)input_width, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_transpose_fp_16_ptr, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + } while (false); +} + +void transpose_cl_axis1(const __fp16 *in, __fp16 *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width) { + + bool result = false; + + do { + ClContext::SharedPtrClKernel kernel_transpose_fp_16_ptr = + cl_context_ref.registerClKernel(transpose_cl_kernel_fp16_axis1, + "transpose_cl_fp16_axis1"); + + if (!kernel_transpose_fp_16_ptr) { + break; + } + + size_t dim_size = sizeof(__fp16) * input_batch_size * input_height * + input_width * input_channels; + + opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); + if (!result) { + break; + } + + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(0, &inputA, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(1, &inOutRes, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments( + 2, &input_batch_size, sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(3, &input_channels, + sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(4, &input_height, + sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(5, &input_width, + sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {(int)input_height, (int)input_width, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_transpose_fp_16_ptr, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + } while (false); +} + +void transpose_cl_axis2(const __fp16 *in, __fp16 *res, + unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int input_width) { + + bool result = false; + + do { + ClContext::SharedPtrClKernel kernel_transpose_fp_16_ptr = + cl_context_ref.registerClKernel(transpose_cl_kernel_fp16_axis2, + "transpose_cl_fp16_axis2"); + + if (!kernel_transpose_fp_16_ptr) { + break; + } + + size_t dim_size = sizeof(__fp16) * input_batch_size * input_height * + input_width * input_channels; + + opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true, + nullptr); + + result = inputA.WriteData(cl_context_ref.command_queue_inst_, in); + if (!result) { + break; + } + + result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(0, &inputA, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(1, &inOutRes, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments( + 2, &input_batch_size, sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(3, &input_channels, + sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(4, &input_height, + sizeof(int)); + if (!result) { + break; + } + + result = kernel_transpose_fp_16_ptr->SetKernelArguments(5, &input_width, + sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {(int)input_channels, (int)input_width, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_transpose_fp_16_ptr, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res); + if (!result) { + break; + } + + } while (false); +} } // namespace nntrainer diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py index 1300dcd8d7..180402fc62 100644 --- a/test/input_gen/gen_layer_tests.py +++ b/test/input_gen/gen_layer_tests.py @@ -954,3 +954,27 @@ def call(self, inputs): rms_normtest_fp16 = RMSNorm() record_single(rms_normtest,(2,3,3,3),"rms_normtest") record_single_fp16(rms_normtest_fp16,(2,3,3,3),"rms_normtest_fp16_new") + + def transpose_axis0(tensor, batch_size, input_channel, input_height, input_width): + output_shape = (batch_size, input_channel, input_height, input_width) + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + transpose_layer_axis0 = tf.keras.layers.Lambda(lambda x: transpose_axis0(x, 2, 3, 3, 3)) + record_single(transpose_layer_axis0, (2, 3, 3, 3), "transpose_axis0", input_type="float") + record_single_fp16(transpose_layer_axis0, (2, 3, 3, 3), "transpose_fp16_axis0", input_type="float") + + def transpose_axis1(tensor, batch_size, input_channel, input_height, input_width): + output_shape = (batch_size, input_channel, input_height, input_width) + return tf.transpose(tensor, perm=[0, 1, 3, 2]) + + transpose_layer_axis1 = tf.keras.layers.Lambda(lambda x: transpose_axis1(x, 2, 3, 3, 3)) + record_single(transpose_layer_axis1, (2, 3, 3, 3), "transpose_axis1", input_type="float") + record_single_fp16(transpose_layer_axis1, (2, 3, 3, 3), "transpose_fp16_axis1", input_type="float") + + def transpose_axis2(tensor, batch_size, input_channel, input_height, input_width): + output_shape = (batch_size, input_channel, input_height, input_width) + return tf.transpose(tensor, perm=[0, 3, 2, 1]) + + transpose_layer_axis2 = tf.keras.layers.Lambda(lambda x: transpose_axis2(x, 2, 3, 3, 3)) + record_single(transpose_layer_axis2, (2, 3, 3, 3), "transpose_axis2", input_type="float") + record_single_fp16(transpose_layer_axis2, (2, 3, 3, 3), "transpose_fp16_axis2", input_type="float") diff --git a/test/jni/Android.mk b/test/jni/Android.mk index faaba46f45..9e1be9dd9e 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -442,6 +442,7 @@ LOCAL_SRC_FILES := \ ../unittest/layers/unittest_layer_node.cpp \ ../unittest/layers/unittest_layers.cpp \ ../unittest/layers/unittest_layers_impl.cpp \ + ../unittest/layers/unittest_layers_transpose_cl.cpp \ ../unittest/layers/unittest_layers_concat_cl.cpp \ ../unittest/layers/unittest_layers_swiglu_cl.cpp \ ../unittest/layers/unittest_layers_fully_connected_cl.cpp \ diff --git a/test/unittest/layers/unittest_layers_transpose_cl.cpp b/test/unittest/layers/unittest_layers_transpose_cl.cpp new file mode 100644 index 0000000000..b7568c4365 --- /dev/null +++ b/test/unittest/layers/unittest_layers_transpose_cl.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file unittest_layers_transpose_cl.cpp + * @date 31 July 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * @brief Transpose Layer Test + */ +#include + +#include + +#include +#include + +auto semantic_transpose_gpu = LayerSemanticsParamType( + nntrainer::createLayer, + nntrainer::TransposeLayerCl::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +GTEST_PARAMETER_TEST(TransposeGPU, LayerSemanticsGpu, + ::testing::Values(semantic_transpose_gpu)); + +auto transpose_basic_plain_axis0 = + LayerGoldenTestParamType(nntrainer::createLayer, + {}, "2:3:3:3", "transpose_axis0.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nchw", "fp32", "fp32"); + +// auto transpose_basic_plain_axis1 = +// LayerGoldenTestParamType(nntrainer::createLayer, +// {}, "2:3:3:3", "transpose_axis1.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD | +// LayerGoldenTestParamOptions::USE_INC_FORWARD, +// "nchw", "fp32", "fp32"); + +// auto transpose_basic_plain_axis2 = +// LayerGoldenTestParamType(nntrainer::createLayer, +// {}, "2:3:3:3", "transpose_axis2.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD | +// LayerGoldenTestParamOptions::USE_INC_FORWARD, +// "nchw", "fp32", "fp32"); + +GTEST_PARAMETER_TEST(TransposeGPU, LayerGoldenTest, + ::testing::Values(transpose_basic_plain_axis0)); + +#ifdef ENABLE_FP16 +auto transpose_basic_plain_w16a16_axis0 = + LayerGoldenTestParamType(nntrainer::createLayer, + {}, "2:3:3:3", "transpose_fp16_axis0.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nchw", "fp16", "fp16"); + +// auto transpose_basic_plain_w16a16_axis1 = +// LayerGoldenTestParamType(nntrainer::createLayer, +// {}, "2:3:3:3", +// "transpose_fp16_axis1.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD | +// LayerGoldenTestParamOptions::USE_INC_FORWARD, +// "nchw", "fp16", "fp16"); + +// auto transpose_basic_plain_w16a16_axis2 = +// LayerGoldenTestParamType(nntrainer::createLayer, +// {}, "2:3:3:3", +// "transpose_fp16_axis2.nnlayergolden", +// LayerGoldenTestParamOptions::SKIP_CALC_DERIV | +// LayerGoldenTestParamOptions::SKIP_CALC_GRAD | +// LayerGoldenTestParamOptions::USE_INC_FORWARD, +// "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST(TransposeGPU16, LayerGoldenTest, + ::testing::Values(transpose_basic_plain_w16a16_axis0)); +#endif