From a1f2f8af33692469db2a7cfae188a1c65bc1dce6 Mon Sep 17 00:00:00 2001 From: Debadri Samaddar Date: Wed, 5 Jun 2024 15:50:35 +0530 Subject: [PATCH] [refactor] Moved blas_kernels to tensor directory Moved common OpenCL blas kernels to tensor directory. Added pre processing functions as common that can be re-used. Signed-off-by: Debadri Samaddar --- nntrainer/layers/cl_layers/fc_layer_cl.cpp | 114 +---------- nntrainer/layers/cl_layers/fc_layer_cl.h | 16 -- nntrainer/layers/cl_layers/meson.build | 5 - .../cl_operations/blas_kernel_interface.cpp | 189 ++++++++++++++++++ .../cl_operations/blas_kernel_interface.h | 48 +++++ .../cl_operations}/blas_kernels.cpp | 12 +- .../cl_operations}/blas_kernels.h | 62 +++--- .../cl_operations}/blas_kernels_fp16.cpp | 12 +- nntrainer/tensor/cl_operations/meson.build | 20 ++ nntrainer/tensor/meson.build | 14 +- 10 files changed, 310 insertions(+), 182 deletions(-) create mode 100644 nntrainer/tensor/cl_operations/blas_kernel_interface.cpp create mode 100644 nntrainer/tensor/cl_operations/blas_kernel_interface.h rename nntrainer/{layers/cl_layers => tensor/cl_operations}/blas_kernels.cpp (95%) rename nntrainer/{layers/cl_layers => tensor/cl_operations}/blas_kernels.h (98%) rename nntrainer/{layers/cl_layers => tensor/cl_operations}/blas_kernels_fp16.cpp (95%) create mode 100644 nntrainer/tensor/cl_operations/meson.build diff --git a/nntrainer/layers/cl_layers/fc_layer_cl.cpp b/nntrainer/layers/cl_layers/fc_layer_cl.cpp index 78c152c88a..890450bebe 100644 --- a/nntrainer/layers/cl_layers/fc_layer_cl.cpp +++ b/nntrainer/layers/cl_layers/fc_layer_cl.cpp @@ -12,7 +12,7 @@ * */ -#include +#include #include #include #include @@ -126,9 +126,9 @@ void FullyConnectedLayerCl::forwarding(RunLayerContext &context, weight.dequantize(weight_, axis); - fcDotProcess(input_, weight_, hidden_, context); + dotCl(input_, weight_, hidden_, context); } else { - fcDotProcess(input_, weight, hidden_, context); + dotCl(input_, weight, hidden_, context); } if (auto &disable_bias = std::get(*layer_impl_props); @@ -138,112 +138,6 @@ void FullyConnectedLayerCl::forwarding(RunLayerContext &context, } } -void FullyConnectedLayerCl::fcDotProcess(Tensor const &input, - Tensor const &weight, Tensor &result, - RunLayerContext &context) { - // to do: - // NNTR_THROW_IF(!contiguous, std::invalid_argument) - // << getName() << " is not contiguous. Cannot dot product."; - - unsigned int dim1, dim2, mdim1, mdim2; - if (input.getFormat() == Tformat::NHWC) { - dim1 = input.batch() * input.height() * input.width(); - dim2 = input.channel(); - mdim1 = weight.batch() * weight.height() * weight.width(); - mdim2 = weight.channel(); - } else { - dim1 = input.batch() * input.channel() * input.height(); - dim2 = input.width(); - mdim1 = weight.batch() * weight.channel() * weight.height(); - mdim2 = weight.width(); - } - - unsigned int M, N, K, lda, ldb, ldc; - if (dim2 != mdim1) - throw std::runtime_error("Error: incompatible dimensions for dot product"); - K = mdim1; /** == dim2 */ - N = mdim2; - M = dim1; - if (input.getFormat() == Tformat::NHWC) { - CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(), - input.width(), - input.getTensorType()); // NHWC Result Tensor - } else { - CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(), - N, input.getTensorType()); - } - - lda = dim2; - ldb = mdim2; - ldc = - (input.getFormat() == Tformat::NHWC) ? result.channel() : result.width(); - - if (input.getDataType() == ml::train::TensorDim::DataType::FP32) { - const float *data = input.getData(); - const float *mdata = weight.getData(); - float *rdata = result.getData(); - - /// shortcut handling in case of vector - /// for vector, (1 * K) == (K * 1) in current memory layout... - /// and plaese note that N, K, M is a fixed place holder after considering - /// transpose. - /// For example, there is no case like (1 * K) X (1 * K) while - /// (1 * K) X (1 * M) can be a case - /// case1: (1 * K) X (K * 1) - if (M == 1 && N == 1) { - *rdata = dot_cl(data, mdata, K, context) + (*rdata); - } - /// case2: (M * K) X (K * 1) - else if (N == 1) { - sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context); - } - /// case3: (1 * K) X (K * N) = 1 * N = R - /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K) - /// Effectively a translation of sgemv - else if (M == 1) { - sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context); - } - /// case others: use gemm - else { - sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context); - } - } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - const _FP16 *data = input.getData<_FP16>(); - const _FP16 *mdata = weight.getData<_FP16>(); - _FP16 *rdata = result.getData<_FP16>(); - const float alpha = 1.0f; - - /// shortcut handling in case of vector - /// for vector, (1 * K) == (K * 1) in current memory layout... - /// and plaese note that N, K, M is a fixed place holder after considering - /// transpose. - /// For example, there is no case like (1 * K) X (1 * K) while - /// (1 * K) X (1 * M) can be a case - /// case1: (1 * K) X (K * 1) - if (M == 1 && N == 1) { - *rdata = dot_cl(data, mdata, K, context) + (*rdata); - } - /// case2: (M * K) X (K * 1) - else if (N == 1) { - sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context); - } - /// case3: (1 * K) X (K * N) = 1 * N = R - /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K) - /// Effectively a translation of sgemv - else if (M == 1) { - sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context); - } - /// case others: use sgemm - else { - sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context); - } -#else - throw std::invalid_argument("Error: enable-fp16 is not enabled"); -#endif - } -} - void FullyConnectedLayerCl::incremental_forwarding(RunLayerContext &context, unsigned int from, unsigned int to, @@ -276,7 +170,7 @@ void FullyConnectedLayerCl::incremental_forwarding(RunLayerContext &context, Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true); Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true); - fcDotProcess(input_step, weight, hidden_step, context); + dotCl(input_step, weight, hidden_step, context); if (auto &disable_bias = std::get(*layer_impl_props); disable_bias.empty() || disable_bias.get() == false) { diff --git a/nntrainer/layers/cl_layers/fc_layer_cl.h b/nntrainer/layers/cl_layers/fc_layer_cl.h index c94ecb22d7..391a83f734 100644 --- a/nntrainer/layers/cl_layers/fc_layer_cl.h +++ b/nntrainer/layers/cl_layers/fc_layer_cl.h @@ -19,12 +19,6 @@ #include #include -#define CREATE_IF_EMPTY_DIMS(tensor, ...) \ - do { \ - if (tensor.empty()) \ - tensor = Tensor(__VA_ARGS__); \ - } while (0); - namespace nntrainer { /** @@ -96,16 +90,6 @@ class FullyConnectedLayerCl : public LayerImpl { return FullyConnectedLayerCl::type; }; - /** - * @brief Process data and dimensions for dot operation used in fc_layer - * @param[in] input Tensor - * @param[in] weight Tensor - * @param[in] result Tensor - * @param[in] RunLayerContext reference - */ - void fcDotProcess(Tensor const &input, Tensor const &weight, Tensor &result, - RunLayerContext &context); - /** * @copydoc Layer::supportBackwarding() */ diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index fd8ed3cae9..5c6ad1358f 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -1,12 +1,7 @@ cl_layer_sources = [ 'fc_layer_cl.cpp', - 'blas_kernels.cpp', ] -if get_option('enable-fp16') - cl_layer_sources += 'blas_kernels_fp16.cpp' -endif - foreach s : cl_layer_sources nntrainer_sources += meson.current_source_dir() / s endforeach diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp new file mode 100644 index 0000000000..852b482529 --- /dev/null +++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file blas_kernel_interface.cpp + * @date 5 June 2024 + * @brief Interface for blas OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + * + */ + +#include +#include + +namespace nntrainer { +void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result, + RunLayerContext &context, bool trans, bool trans_m) { + if (!result.isAllocated()) + throw std::invalid_argument( + "Output tensor must be preallocated for dotBatched operation"); + for (unsigned int b = 0; b < input.batch(); b++) { + /** @todo try using transpose to speedup the operation */ + const Tensor this_b = input.getBatchSlice(b, 1); + Tensor m_b = m.getBatchSlice(b, 1); + Tensor result_b = result.getBatchSlice(b, 1); + + dotCl(this_b, m_b, result_b, context, trans, trans_m); + } +} + +void dotCl(Tensor const &input, Tensor const &m, Tensor &result, + RunLayerContext &context, bool trans, bool trans_m) { + unsigned int dim1, dim2, mdim1, mdim2; + if (input.getFormat() == Tformat::NHWC) { + dim1 = input.batch() * input.height() * input.width(); + dim2 = input.channel(); + mdim1 = m.batch() * m.height() * m.width(); + mdim2 = m.channel(); + } else { + dim1 = input.batch() * input.channel() * input.height(); + dim2 = input.width(); + mdim1 = m.batch() * m.channel() * m.height(); + mdim2 = m.width(); + } + + unsigned int M, N, K, lda, ldb, ldc; + + if (!trans && !trans_m) { + if (dim2 != mdim1) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim1; /** == dim2 */ + N = mdim2; + M = dim1; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(), + input.width(), + input.getTensorType()); // NHWC Result Tensor + } else { + CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), + input.height(), N, input.getTensorType()); + } + } else if (!trans && trans_m) { + if (dim2 != mdim2) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim2; /** == dim2 */ + N = mdim1; + M = dim1; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(), + input.width(), input.getTensorType()); + } else { + CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), + input.height(), N, input.getTensorType()); + } + } else if (trans && !trans_m) { + if (dim1 != mdim1) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim1; /** == dim1 */ + N = mdim2; + M = dim2; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, 1, N, M, 1, input.getTensorType()); + } else { + CREATE_IF_EMPTY_DIMS(result, 1, 1, M, N, input.getTensorType()); + } + } else { + if (dim1 != mdim2) + throw std::runtime_error( + "Error: incompatible dimensions for dot product"); + K = mdim2; /** == dim1 */ + N = mdim1; + M = dim2; + if (input.getFormat() == Tformat::NHWC) { + CREATE_IF_EMPTY_DIMS(result, 1, N, M, 1, input.getTensorType()); + } else { + CREATE_IF_EMPTY_DIMS(result, 1, 1, M, N, input.getTensorType()); + } + } + + lda = dim2; + ldb = mdim2; + ldc = + (input.getFormat() == Tformat::NHWC) ? result.channel() : result.width(); + + if (input.getDataType() == ml::train::TensorDim::DataType::FP32) { + const float *data = input.getData(); + const float *mdata = m.getData(); + float *rdata = result.getData(); + enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans; + enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans; + + /// shortcut handling in case of vector + /// for vector, (1 * K) == (K * 1) in current memory layout... + /// and plaese note that N, K, M is a fixed place holder after considering + /// transpose. + /// For example, there is no case like (1 * K) X (1 * K) while + /// (1 * K) X (1 * M) can be a case + /// case1: (1 * K) X (K * 1) + if (M == 1 && N == 1) { + *rdata = dot_cl(data, mdata, K, context) + (*rdata); + } + /// case2: (M * K) X (K * 1) + else if (N == 1) { + transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context) + : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context); + } + /// case3: (1 * K) X (K * N) = 1 * N = R + /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K) + /// Effectively a translation of sgemv + else if (M == 1) { + transB = transB == CblasTrans ? CblasNoTrans : CblasTrans; + transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context) + : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context); + } + /// case others: use gemm + else { + // transA == false, transB == false + sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context); + // todo: other condition implementations + } + } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *data = input.getData<_FP16>(); + const _FP16 *mdata = m.getData<_FP16>(); + _FP16 *rdata = result.getData<_FP16>(); + enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans; + enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans; + + /// shortcut handling in case of vector + /// for vector, (1 * K) == (K * 1) in current memory layout... + /// and plaese note that N, K, M is a fixed place holder after considering + /// transpose. + /// For example, there is no case like (1 * K) X (1 * K) while + /// (1 * K) X (1 * M) can be a case + /// case1: (1 * K) X (K * 1) + if (M == 1 && N == 1) { + *rdata = dot_cl(data, mdata, K, context) + (*rdata); + } + /// case2: (M * K) X (K * 1) + else if (N == 1) { + transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context) + : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context); + } + /// case3: (1 * K) X (K * N) = 1 * N = R + /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K) + /// Effectively a translation of sgemv + else if (M == 1) { + transB = transB == CblasTrans ? CblasNoTrans : CblasTrans; + transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context) + : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context); + } + /// case others: use sgemm + else { + // transA == false, transB == false + sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context); + // todo: other condition implementations + } +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } +} + +} // namespace nntrainer diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.h b/nntrainer/tensor/cl_operations/blas_kernel_interface.h new file mode 100644 index 0000000000..41fdfd9242 --- /dev/null +++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.h @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Debadri Samaddar + * + * @file blas_kernel_interface.h + * @date 5 June 2024 + * @brief Interface for blas OpenCL kernels + * @see https://github.com/nnstreamer/nntrainer + * @author Debadri Samaddar + * @bug No known bugs except for NYI items + * + */ + +#ifndef __BLAS_KERNEL_INTERFACE_H__ +#define __BLAS_KERNEL_INTERFACE_H__ + +#include +#include + +namespace nntrainer { + +/** + * @brief Process data and dimensions for OpenCL dot operation + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] result Tensor + * @param[in] RunLayerContext reference + * @param[in] trans bool + * @param[in] trans_m bool + */ +void dotCl(Tensor const &input, Tensor const &m, Tensor &result, + RunLayerContext &context, bool trans = false, bool trans_m = false); + +/** + * @brief Process data and dimensions for OpenCL dot operation + * @param[in] input Tensor + * @param[in] m Tensor + * @param[in] result Tensor + * @param[in] RunLayerContext reference + * @param[in] trans bool + * @param[in] trans_m bool + */ +void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result, + RunLayerContext &context, bool trans = false, + bool trans_m = false); + +} // namespace nntrainer +#endif /* __BLAS_KERNEL_INTERFACE_H__ */ diff --git a/nntrainer/layers/cl_layers/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp similarity index 95% rename from nntrainer/layers/cl_layers/blas_kernels.cpp rename to nntrainer/tensor/cl_operations/blas_kernels.cpp index b994afd731..4c54a0b262 100644 --- a/nntrainer/layers/cl_layers/blas_kernels.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp @@ -17,11 +17,11 @@ namespace nntrainer { std::string sgemv_cl_kernel_ = R"(__kernel void sgemv_cl(const __global float* A, const __global float* X, - __global float* Y, unsigned int M, unsigned int lda) { + __global float* Y, unsigned int N, unsigned int lda) { unsigned int i; i = get_global_id(0); float y0 = 0.0f; - for (unsigned int j = 0; j < M; j++) + for (unsigned int j = 0; j < N; j++) y0 += A[i + j * lda] * X[j]; Y[i] = y0; @@ -76,9 +76,9 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata, opencl::Buffer inputA(context.context_inst_, dim1 * dim2 * sizeof(float), true, nullptr); - opencl::Buffer inputX(context.context_inst_, dim1_size, true, nullptr); + opencl::Buffer inputX(context.context_inst_, dim2_size, true, nullptr); - opencl::Buffer inOutY(context.context_inst_, dim2_size, true, nullptr); + opencl::Buffer inOutY(context.context_inst_, dim1_size, true, nullptr); result = inputA.WriteData(context.command_queue_inst_, matAdata); if (!result) { @@ -110,7 +110,7 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata, break; } - result = kernel_sgemv.SetKernelArguments(3, &dim1, sizeof(int)); + result = kernel_sgemv.SetKernelArguments(3, &dim2, sizeof(int)); if (!result) { break; } @@ -120,7 +120,7 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata, break; } - const int work_groups_count[3] = {(int)dim2, 1, 1}; + const int work_groups_count[3] = {(int)dim1, 1, 1}; const int work_group_size[3] = {32, 32, 1}; // test-value result = context.command_queue_inst_.DispatchCommand( diff --git a/nntrainer/layers/cl_layers/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h similarity index 98% rename from nntrainer/layers/cl_layers/blas_kernels.h rename to nntrainer/tensor/cl_operations/blas_kernels.h index 57c82ef8ad..d9f06490b0 100644 --- a/nntrainer/layers/cl_layers/blas_kernels.h +++ b/nntrainer/tensor/cl_operations/blas_kernels.h @@ -25,11 +25,8 @@ namespace nntrainer { * @brief declaring global kernel objects */ extern opencl::Kernel kernel_sgemv; -extern opencl::Kernel kernel_sgemv_fp16; extern opencl::Kernel kernel_sgemm; -extern opencl::Kernel kernel_sgemm_fp16; extern opencl::Kernel kernel_dot; -extern opencl::Kernel kernel_dot_fp16; /** * @brief sgemv computation : Y = A*X + Y @@ -45,20 +42,6 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata, unsigned int dim1, unsigned int dim2, unsigned int lda, RunLayerContext &context); -/** - * @brief fp16 sgemv computation : Y = A*X + Y - * @param[in] matAdata fp16 * for Matrix A - * @param[in] vecXdata fp16 * for Vector X - * @param[in] vecYdata fp16 * for Vector Y - * @param[in] dim1 number of A's columns - * @param[in] dim2 number of A's rows - * @param[in] lda number of X's columns - * @param[in] context RunLayerContext reference - */ -void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, - unsigned int dim1, unsigned int dim2, unsigned int lda, - RunLayerContext &context); - /** * @brief dot computation : sum of all X * Y * @param[in] vecAdata float * for Vector A @@ -70,17 +53,6 @@ void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1, RunLayerContext &context); -/** - * @brief fp16 dot computation : sum of all X * Y - * @param[in] vecAdata fp16 * for Vector A - * @param[in] vecXdata fp16 * for Vector X - * @param[in] dim1 number of elements in both input vectors - * @param[in] context RunLayerContext reference - * @return fp16 dot product result - */ -__fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1, - RunLayerContext &context); - /** * @brief sgemm computation : Y = op(A)*op(B) + C, * where op(X) is one of X or X**T @@ -99,6 +71,39 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M, unsigned int N, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc, RunLayerContext &context); +#ifdef ENABLE_FP16 +/** + * @brief declaring global fp16 kernel objects + */ +extern opencl::Kernel kernel_sgemv_fp16; +extern opencl::Kernel kernel_sgemm_fp16; +extern opencl::Kernel kernel_dot_fp16; + +/** + * @brief fp16 sgemv computation : Y = A*X + Y + * @param[in] matAdata fp16 * for Matrix A + * @param[in] vecXdata fp16 * for Vector X + * @param[in] vecYdata fp16 * for Vector Y + * @param[in] dim1 number of A's columns + * @param[in] dim2 number of A's rows + * @param[in] lda number of X's columns + * @param[in] context RunLayerContext reference + */ +void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, + unsigned int dim1, unsigned int dim2, unsigned int lda, + RunLayerContext &context); + +/** + * @brief fp16 dot computation : sum of all X * Y + * @param[in] vecAdata fp16 * for Vector A + * @param[in] vecXdata fp16 * for Vector X + * @param[in] dim1 number of elements in both input vectors + * @param[in] context RunLayerContext reference + * @return fp16 dot product result + */ +__fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1, + RunLayerContext &context); + /** * @brief fp16 sgemm computation : Y = op(A)*op(B) + C, * where op(X) is one of X or X**T @@ -116,6 +121,7 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M, void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc, RunLayerContext &context); +#endif } // namespace nntrainer #endif /* __BLAS_KERNELS_H__ */ diff --git a/nntrainer/layers/cl_layers/blas_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp similarity index 95% rename from nntrainer/layers/cl_layers/blas_kernels_fp16.cpp rename to nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp index c85b0532ce..8948f0dc5c 100644 --- a/nntrainer/layers/cl_layers/blas_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp @@ -20,11 +20,11 @@ std::string sgemv_cl_kernel_fp16_ = #pragma OPENCL EXTENSION cl_khr_fp16 : enable __kernel void sgemv_cl_fp16(const __global half* A, const __global half* X, - __global half* Y, unsigned int M, unsigned int lda) { + __global half* Y, unsigned int N, unsigned int lda) { unsigned int i; i = get_global_id(0); half y0 = 0.0f; - for (unsigned int j = 0; j < M; j++) + for (unsigned int j = 0; j < N; j++) y0 += A[i + j * lda] * X[j]; Y[i] = y0; @@ -86,9 +86,9 @@ void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, opencl::Buffer inputA(context.context_inst_, dim1 * dim2 * sizeof(cl_half), true, nullptr); - opencl::Buffer inputX(context.context_inst_, dim1_size, true, nullptr); + opencl::Buffer inputX(context.context_inst_, dim2_size, true, nullptr); - opencl::Buffer inOutY(context.context_inst_, dim2_size, true, nullptr); + opencl::Buffer inOutY(context.context_inst_, dim1_size, true, nullptr); result = inputA.WriteData(context.command_queue_inst_, matAdata); if (!result) { @@ -120,7 +120,7 @@ void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, break; } - result = kernel_sgemv_fp16.SetKernelArguments(3, &dim1, sizeof(int)); + result = kernel_sgemv_fp16.SetKernelArguments(3, &dim2, sizeof(int)); if (!result) { break; } @@ -130,7 +130,7 @@ void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, break; } - const int work_groups_count[3] = {(int)dim2, 1, 1}; + const int work_groups_count[3] = {(int)dim1, 1, 1}; const int work_group_size[3] = {32, 32, 1}; // test-value result = context.command_queue_inst_.DispatchCommand( diff --git a/nntrainer/tensor/cl_operations/meson.build b/nntrainer/tensor/cl_operations/meson.build new file mode 100644 index 0000000000..4cff3e0c4a --- /dev/null +++ b/nntrainer/tensor/cl_operations/meson.build @@ -0,0 +1,20 @@ +cl_op_sources = [ + 'blas_kernels.cpp', + 'blas_kernel_interface.cpp', +] + +cl_op_headers = [ + 'blas_kernel_interface.h', +] + +if get_option('enable-fp16') + cl_op_sources += 'blas_kernels_fp16.cpp' +endif + +foreach s : cl_op_sources + nntrainer_sources += meson.current_source_dir() / s +endforeach + +foreach h : cl_op_headers + nntrainer_headers += meson.current_source_dir() / h +endforeach diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build index fe4204cf85..202b730060 100644 --- a/nntrainer/tensor/meson.build +++ b/nntrainer/tensor/meson.build @@ -34,15 +34,6 @@ tensor_headers = [ 'blas_interface.h' ] -cl_sources = [ - 'cl_operations/cl_sgemv.cpp' -] - -cl_headers = [ - 'cl_operations/cl_interface.h' -] - - arch = host_machine.cpu_family() if get_option('enable-fp16') if arch == 'arm' @@ -71,8 +62,9 @@ if get_option('enable-fp16') endif if get_option('enable-opencl') - tensor_sources += cl_sources - tensor_headers += cl_headers + subdir('cl_operations') + nntrainer_inc += include_directories('cl_operations') + nntrainer_inc_abs += meson.current_source_dir() / 'cl_operations' endif foreach s : tensor_sources