From 3661ec8b09013245f4fb7624a9456527b96fe465 Mon Sep 17 00:00:00 2001
From: Niket Agarwal
Date: Thu, 10 Oct 2024 16:30:58 +0530
Subject: [PATCH] [GPU/OpenCL] RMSNorm Bug Fix

Index value of alpha corrected in kernel logic.
Updated RMSNorm with the new shared_ptr flow.
Replaced clCreateKernel with registerClKernel.

Self evaluation:

  Build test: [X]Passed [ ]Failed [ ]Skipped
  Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Niket Agarwal
---
 nntrainer/cl_context.cpp                      |   5 +-
 nntrainer/layers/cl_layers/meson.build        |   2 +-
 .../layers/cl_layers/rmsnorm_layer_cl.cpp     | 124 ++++++++----
 nntrainer/layers/cl_layers/rmsnorm_layer_cl.h |  22 ++--
 4 files changed, 73 insertions(+), 80 deletions(-)

diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp
index df48a83d17..a050a4f0f7 100644
--- a/nntrainer/cl_context.cpp
+++ b/nntrainer/cl_context.cpp
@@ -45,9 +45,8 @@ static void add_default_object(ClContext &cc) {
   cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
                      ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);
 
-  // cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
-  //                    RMSNormLayerCl::type,
-  //                    ml::train::LayerType::LAYER_RMSNORM);
+  cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
+                     RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
 
   cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>, ConcatLayerCl::type,
                      ml::train::LayerType::LAYER_CONCAT);
diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build
index 8aed7f2a79..e229f44069 100644
--- a/nntrainer/layers/cl_layers/meson.build
+++ b/nntrainer/layers/cl_layers/meson.build
@@ -3,7 +3,7 @@ cl_layer_sources = [
   # 'addition_layer_cl.cpp',
   'swiglu_cl.cpp',
   'reshape_cl.cpp',
-  # 'rmsnorm_layer_cl.cpp',
+  'rmsnorm_layer_cl.cpp',
   'concat_cl.cpp',
 ]
diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
index 96e9a53069..179b89fa8a 100644
--- a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
+++ b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
@@ -50,8 +50,8 @@ std::string rmsnorm_cl_kernel_fp16_ =
     half rms_norm = sqrt(sum_squares + epsilon);
     // Each work item processes all width elements for its specific n, h, c
     for (int w = 0; w < W; ++w) {
-      output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
-    }
+      output[index+w] = (input[index+w] / rms_norm) * alpha[w];
+    }
   }
 )";
@@ -80,7 +80,7 @@ std::string rmsnorm_cl_kernel_ =
     float rms_norm = sqrt(sum_squares + epsilon);
     // Each work item processes all width elements for its specific n, h, c
     for (int w = 0; w < W; ++w) {
-      output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
+      output[index+w] = (input[index+w] / rms_norm) * alpha[w];
     }
   }
 )";
@@ -113,9 +113,13 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
   Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
   auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
   if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
-    rmsnormProcess(in, out, gamma, epsilon, context);
+    rmsnormProcess(in, out, gamma, epsilon);
   } else {
-    rmsnormProcess_fp16(in, out, gamma, epsilon, context);
+#ifdef ENABLE_FP16
+    rmsnormProcess_fp16(in, out, gamma, epsilon);
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
   }
 }
 
@@ -123,8 +127,7 @@ opencl::Kernel RMSNormLayerCl::kernel_rmsnorm;
 opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;
 
 void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
-                                    Tensor const &gamma, const float epsilon,
-                                    RunLayerContext &context) {
+                                    Tensor const &gamma, const float epsilon) {
   bool ret = false;
   int dim1 = input.batch() * input.height() * input.width() * input.channel();
   CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
@@ -133,86 +136,82 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
   int c = input.channel();
   int h = input.height();
   int w = input.width();
+
   do {
-    ret =
-      context.clCreateKernel(rmsnorm_cl_kernel_, context.LayerKernel::RMSNORM,
-                             RMSNormLayerCl::kernel_rmsnorm);
-    if (!ret) {
+    ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
+      cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
+    if (!kernel_rmsnorm_ptr) {
       break;
     }
 
-    opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(float), true,
-                            nullptr);
+    opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
+                            true, nullptr);
 
-    opencl::Buffer gammabuf(context.context_inst_,
+    opencl::Buffer gammabuf(cl_context_ref.context_inst_,
                             input.width() * sizeof(float), true, nullptr);
-    opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(float), true,
-                             nullptr);
+    opencl::Buffer resultbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
+                             true, nullptr);
 
     const float *data = input.getData();
     float *rdata = result.getData();
     const float *gdata = gamma.getData();
-    ret = inputbuf.WriteData(context.command_queue_inst_, data);
+    ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
     if (!ret) {
       break;
     }
 
-    ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+    ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
     if (!ret) {
       break;
    }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(0, &inputbuf,
-                                                            sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(1, &resultbuf,
-                                                            sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(2, &gammabuf,
-                                                            sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(4, &b, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
     if (!ret) {
      break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(3, &epsilon,
-                                                            sizeof(float));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(float));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(5, &c, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(6, &h, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(7, &w, sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
     if (!ret) {
       break;
     }
 
     const int work_groups_count[3] = {b * c, h, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
-    ret = context.command_queue_inst_.DispatchCommand(
-      RMSNormLayerCl::kernel_rmsnorm, work_groups_count, work_group_size);
+    ret = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rmsnorm_ptr, work_groups_count, work_group_size);
     if (!ret) {
       break;
     }
 
-    ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+    ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
     if (!ret) {
      break;
    }
@@ -222,8 +221,7 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
 
 void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
                                          Tensor const &gamma,
-                                         const float epsilon,
-                                         RunLayerContext &context) {
+                                         const float epsilon) {
   bool ret = false;
 
   int dim1 = input.batch() * input.height() * input.width() * input.channel();
@@ -234,85 +232,77 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
   int h = input.height();
   int w = input.width();
   do {
-    ret = context.clCreateKernel(rmsnorm_cl_kernel_fp16_,
-                                 context.LayerKernel::RMSNORM_FP16,
-                                 RMSNormLayerCl::kernel_rmsnorm_fp16);
-    if (!ret) {
+    ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
+      cl_context_ref.registerClKernel(rmsnorm_cl_kernel_fp16_,
+                                      "rmsnorm_cl_fp16");
+    if (!kernel_rmsnorm_ptr) {
      break;
    }
 
-    opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(cl_half), true,
-                            nullptr);
+    opencl::Buffer inputbuf(cl_context_ref.context_inst_,
+                            dim1 * sizeof(cl_half), true, nullptr);
 
-    opencl::Buffer gammabuf(context.context_inst_,
+    opencl::Buffer gammabuf(cl_context_ref.context_inst_,
                             input.width() * sizeof(cl_half), true, nullptr);
-    opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(cl_half),
-                             true, nullptr);
+    opencl::Buffer resultbuf(cl_context_ref.context_inst_,
+                             dim1 * sizeof(cl_half), true, nullptr);
 
     const __fp16 *data = input.getData<__fp16>();
     __fp16 *rdata = result.getData<__fp16>();
     const __fp16 *gdata = gamma.getData<__fp16>();
-    ret = inputbuf.WriteData(context.command_queue_inst_, data);
+    ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
     if (!ret) {
       break;
     }
 
-    ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
+    ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      0, &inputbuf, sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      1, &resultbuf, sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      2, &gammabuf, sizeof(cl_mem));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
-      3, &epsilon, sizeof(cl_half));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(cl_half));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(5, &c,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
     if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(6, &h,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
    if (!ret) {
       break;
     }
 
-    ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(7, &w,
-                                                                 sizeof(int));
+    ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
     if (!ret) {
       break;
     }
 
     const int work_groups_count[3] = {b * c, h, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
-    ret = context.command_queue_inst_.DispatchCommand(
-      RMSNormLayerCl::kernel_rmsnorm_fp16, work_groups_count, work_group_size);
+    ret = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rmsnorm_ptr, work_groups_count, work_group_size);
     if (!ret) {
       break;
     }
 
-    ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
+    ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
     if (!ret) {
       break;
     }
@@ -347,9 +337,9 @@ void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
   auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
 
   if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
-    rmsnormProcess(in, out, gamma, epsilon, context);
+    rmsnormProcess(in, out, gamma, epsilon);
   } else {
-    rmsnormProcess_fp16(in, out, gamma, epsilon, context);
+    rmsnormProcess_fp16(in, out, gamma, epsilon);
   }
 }
diff --git a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
index 4b34729409..43f942ea1e 100644
--- a/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
+++ b/nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
@@ -19,6 +19,7 @@
 #include
 #include
+#include <cl_context.h>
 #include
 #include
@@ -49,7 +50,12 @@ class RMS_NORM_GAMMA_INIT_GPU final
  * @class RMSNormLayer
  * @brief RMS Norm layer
  */
+
 class RMSNormLayerCl : public LayerImpl {
+
+private:
+  inline static ClContext cl_context_ref;
+
 public:
   /**
    * @brief Constructor of RMS Norm Layer
@@ -84,9 +90,9 @@ class RMSNormLayerCl : public LayerImpl {
   void forwarding(RunLayerContext &context, bool training) override;
 
   /**
-    * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
-    * int from, unsigned int to, bool training)
-    */
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
   void incremental_forwarding(RunLayerContext &context, unsigned int from,
                               unsigned int to, bool training) override;
 
@@ -121,24 +127,22 @@ class RMSNormLayerCl : public LayerImpl {
    * @param[in] result Tensor
    * @param[in] gamma Tensor
    * @param[in] epsilon float
-   * @param[in] RunLayerContext reference
    */
   void rmsnormProcess(Tensor const &input, Tensor &result, Tensor const &gamma,
-                      const float epsilon, RunLayerContext &context);
-
+                      const float epsilon);
+#ifdef ENABLE_FP16
   /**
   * @brief Process data and dimensions for FP16 rms norm operation
   * @param[in] input Tensor
   * @param[in] result Tensor
   * @param[in] gamma Tensor
   * @param[in] epsilon float
-   * @param[in] RunLayerContext reference
   */
   void rmsnormProcess_fp16(Tensor const &input, Tensor &result,
-                           Tensor const &gamma, const float epsilon,
-                           RunLayerContext &context);
+                           Tensor const &gamma, const float epsilon);
+#endif
 
   /**
    * @copydoc Layer::supportBackwarding()
   */