[GPU/OpenCL] RMSNorm Bug Fix - Index value of alpha corrected in kernel logic.

Updated RMSNorm with the new shared_ptr flow.
Replaced clCreateKernel with registerClKernel.

Self evaluation:

    Build test: [X]Passed [ ]Failed [ ]Skipped
    Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Niket Agarwal <[email protected]>
niket-agarwal committed Oct 10, 2024
1 parent ccfc1de commit 3661ec8
Showing 4 changed files with 73 additions and 80 deletions.
5 changes: 2 additions & 3 deletions nntrainer/cl_context.cpp
@@ -45,9 +45,8 @@ static void add_default_object(ClContext &cc) {
cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);

// cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
// RMSNormLayerCl::type,
// ml::train::LayerType::LAYER_RMSNORM);
cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);

cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>, ConcatLayerCl::type,
ml::train::LayerType::LAYER_CONCAT);
2 changes: 1 addition & 1 deletion nntrainer/layers/cl_layers/meson.build
@@ -3,7 +3,7 @@ cl_layer_sources = [
# 'addition_layer_cl.cpp',
'swiglu_cl.cpp',
'reshape_cl.cpp',
# 'rmsnorm_layer_cl.cpp',
'rmsnorm_layer_cl.cpp',
'concat_cl.cpp',
]

124 changes: 57 additions & 67 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.cpp
@@ -50,8 +50,8 @@ std::string rmsnorm_cl_kernel_fp16_ =
half rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
}
output[index+w] = (input[index+w] / rms_norm) * alpha[w];
}
}
)";

@@ -80,7 +80,7 @@ std::string rmsnorm_cl_kernel_ =
float rms_norm = sqrt(sum_squares + epsilon);
// Each work item processes all width elements for its specific n, h, c
for (int w = 0; w < W; ++w) {
output[index+w] = (input[index+w] / rms_norm) * alpha[index+w];
output[index+w] = (input[index+w] / rms_norm) * alpha[w];
}
}
)";
@@ -113,18 +113,21 @@ void RMSNormLayerCl::forwarding(RunLayerContext &context, bool training) {
Tensor &gamma = context.getWeight(wt_idx[RMSParams::gamma]);
auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();
if (in.getDataType() == ml::train::TensorDim::DataType::FP32) {
rmsnormProcess(in, out, gamma, epsilon, context);
rmsnormProcess(in, out, gamma, epsilon);
} else {
rmsnormProcess_fp16(in, out, gamma, epsilon, context);
#ifdef ENABLE_FP16
rmsnormProcess_fp16(in, out, gamma, epsilon);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
}
}

opencl::Kernel RMSNormLayerCl::kernel_rmsnorm;
opencl::Kernel RMSNormLayerCl::kernel_rmsnorm_fp16;

void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
Tensor const &gamma, const float epsilon,
RunLayerContext &context) {
Tensor const &gamma, const float epsilon) {
bool ret = false;
int dim1 = input.batch() * input.height() * input.width() * input.channel();
CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(), input.height(),
@@ -133,86 +136,82 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,
int c = input.channel();
int h = input.height();
int w = input.width();

do {
ret =
context.clCreateKernel(rmsnorm_cl_kernel_, context.LayerKernel::RMSNORM,
RMSNormLayerCl::kernel_rmsnorm);
if (!ret) {
ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
if (!kernel_rmsnorm_ptr) {
break;
}

opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(float), true,
nullptr);
opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
true, nullptr);

opencl::Buffer gammabuf(context.context_inst_,
opencl::Buffer gammabuf(cl_context_ref.context_inst_,
input.width() * sizeof(float), true, nullptr);
opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(float), true,
nullptr);
opencl::Buffer resultbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
true, nullptr);

const float *data = input.getData();
float *rdata = result.getData();
const float *gdata = gamma.getData();
ret = inputbuf.WriteData(context.command_queue_inst_, data);
ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
if (!ret) {
break;
}

ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(0, &inputbuf,
sizeof(cl_mem));
ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(1, &resultbuf,
sizeof(cl_mem));
ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(2, &gammabuf,
sizeof(cl_mem));
ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(4, &b, sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));

if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(3, &epsilon,
sizeof(float));
ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(float));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(5, &c, sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(6, &h, sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm.SetKernelArguments(7, &w, sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
if (!ret) {
break;
}
const int work_groups_count[3] = {b * c, h, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value

ret = context.command_queue_inst_.DispatchCommand(
RMSNormLayerCl::kernel_rmsnorm, work_groups_count, work_group_size);
ret = cl_context_ref.command_queue_inst_.DispatchCommand(
kernel_rmsnorm_ptr, work_groups_count, work_group_size);
if (!ret) {
break;
}

ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
if (!ret) {
break;
}
@@ -222,8 +221,7 @@ void RMSNormLayerCl::rmsnormProcess(Tensor const &input, Tensor &result,

void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
Tensor const &gamma,
const float epsilon,
RunLayerContext &context) {
const float epsilon) {

bool ret = false;
int dim1 = input.batch() * input.height() * input.width() * input.channel();
@@ -234,85 +232,77 @@ void RMSNormLayerCl::rmsnormProcess_fp16(Tensor const &input, Tensor &result,
int h = input.height();
int w = input.width();
do {
ret = context.clCreateKernel(rmsnorm_cl_kernel_fp16_,
context.LayerKernel::RMSNORM_FP16,
RMSNormLayerCl::kernel_rmsnorm_fp16);
if (!ret) {
ClContext::SharedPtrClKernel kernel_rmsnorm_ptr =
cl_context_ref.registerClKernel(rmsnorm_cl_kernel_fp16_,
"rmsnorm_cl_fp16");
if (!kernel_rmsnorm_ptr) {
break;
}
opencl::Buffer inputbuf(context.context_inst_, dim1 * sizeof(cl_half), true,
nullptr);
opencl::Buffer inputbuf(cl_context_ref.context_inst_,
dim1 * sizeof(cl_half), true, nullptr);

opencl::Buffer gammabuf(context.context_inst_,
opencl::Buffer gammabuf(cl_context_ref.context_inst_,
input.width() * sizeof(cl_half), true, nullptr);
opencl::Buffer resultbuf(context.context_inst_, dim1 * sizeof(cl_half),
true, nullptr);
opencl::Buffer resultbuf(cl_context_ref.context_inst_,
dim1 * sizeof(cl_half), true, nullptr);

const __fp16 *data = input.getData<__fp16>();
__fp16 *rdata = result.getData<__fp16>();
const __fp16 *gdata = gamma.getData<__fp16>();
ret = inputbuf.WriteData(context.command_queue_inst_, data);
ret = inputbuf.WriteData(cl_context_ref.command_queue_inst_, data);
if (!ret) {
break;
}

ret = gammabuf.WriteData(context.command_queue_inst_, gdata);
ret = gammabuf.WriteData(cl_context_ref.command_queue_inst_, gdata);
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
0, &inputbuf, sizeof(cl_mem));
ret = kernel_rmsnorm_ptr->SetKernelArguments(0, &inputbuf, sizeof(cl_mem));
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
1, &resultbuf, sizeof(cl_mem));
ret = kernel_rmsnorm_ptr->SetKernelArguments(1, &resultbuf, sizeof(cl_mem));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
2, &gammabuf, sizeof(cl_mem));
ret = kernel_rmsnorm_ptr->SetKernelArguments(2, &gammabuf, sizeof(cl_mem));
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(4, &b,
sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(4, &b, sizeof(int));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(
3, &epsilon, sizeof(cl_half));
ret = kernel_rmsnorm_ptr->SetKernelArguments(3, &epsilon, sizeof(cl_half));
if (!ret) {
break;
}

ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(5, &c,
sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(5, &c, sizeof(int));
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(6, &h,
sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(6, &h, sizeof(int));
if (!ret) {
break;
}
ret = RMSNormLayerCl::kernel_rmsnorm_fp16.SetKernelArguments(7, &w,
sizeof(int));
ret = kernel_rmsnorm_ptr->SetKernelArguments(7, &w, sizeof(int));
if (!ret) {
break;
}
const int work_groups_count[3] = {b * c, h, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value

ret = context.command_queue_inst_.DispatchCommand(
RMSNormLayerCl::kernel_rmsnorm_fp16, work_groups_count, work_group_size);
ret = cl_context_ref.command_queue_inst_.DispatchCommand(
kernel_rmsnorm_ptr, work_groups_count, work_group_size);
if (!ret) {
break;
}

ret = resultbuf.ReadData(context.command_queue_inst_, rdata);
ret = resultbuf.ReadData(cl_context_ref.command_queue_inst_, rdata);
if (!ret) {
break;
}
@@ -347,9 +337,9 @@ void RMSNormLayerCl::incremental_forwarding(nntrainer::RunLayerContext &context,
auto &epsilon = std::get<props::Epsilon>(rmsnorm_props).get();

if (in_step.getDataType() == ml::train::TensorDim::DataType::FP32) {
rmsnormProcess(in, out, gamma, epsilon, context);
rmsnormProcess(in, out, gamma, epsilon);
} else {
rmsnormProcess_fp16(in, out, gamma, epsilon, context);
rmsnormProcess_fp16(in, out, gamma, epsilon);
}
}
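Taken together, the changes in this file replace the per-call clCreateKernel and static opencl::Kernel pattern with kernels obtained from the shared ClContext. A condensed sketch of the resulting FP32 flow follows; it uses only calls that appear in this diff, while the helper name rmsnormProcessSketch and the idea that registerClKernel caches the compiled kernel are assumptions. The FP16 path differs only in cl_half-sized buffers and the "rmsnorm_cl_fp16" kernel name.

// Hypothetical condensation of the FP32 path of rmsnormProcess, assuming the
// same includes as rmsnorm_layer_cl.cpp. cl_context_ref stands in for the
// inline static ClContext added to the header; CREATE_IF_EMPTY_DIMS and the
// do/while error handling are dropped for brevity.
void rmsnormProcessSketch(ClContext &cl_context_ref, Tensor const &input,
                          Tensor &result, Tensor const &gamma,
                          const float epsilon) {
  // Register (or fetch) the kernel; a null shared_ptr signals failure.
  ClContext::SharedPtrClKernel kernel =
    cl_context_ref.registerClKernel(rmsnorm_cl_kernel_, "rmsnorm_cl");
  if (!kernel)
    return;

  int b = input.batch(), c = input.channel(), h = input.height(),
      w = input.width();
  int dim1 = b * c * h * w;

  opencl::Buffer inputbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
                          true, nullptr);
  opencl::Buffer gammabuf(cl_context_ref.context_inst_, w * sizeof(float),
                          true, nullptr);
  opencl::Buffer resultbuf(cl_context_ref.context_inst_, dim1 * sizeof(float),
                           true, nullptr);

  // Host -> device copies of the input and the gamma weights.
  if (!inputbuf.WriteData(cl_context_ref.command_queue_inst_, input.getData()) ||
      !gammabuf.WriteData(cl_context_ref.command_queue_inst_, gamma.getData()))
    return;

  // Argument order matches the kernel signature: buffers 0-2, epsilon 3,
  // then B, C, H, W as 4-7.
  if (!kernel->SetKernelArguments(0, &inputbuf, sizeof(cl_mem)) ||
      !kernel->SetKernelArguments(1, &resultbuf, sizeof(cl_mem)) ||
      !kernel->SetKernelArguments(2, &gammabuf, sizeof(cl_mem)) ||
      !kernel->SetKernelArguments(3, &epsilon, sizeof(float)) ||
      !kernel->SetKernelArguments(4, &b, sizeof(int)) ||
      !kernel->SetKernelArguments(5, &c, sizeof(int)) ||
      !kernel->SetKernelArguments(6, &h, sizeof(int)) ||
      !kernel->SetKernelArguments(7, &w, sizeof(int)))
    return;

  // One work item per (n, c, h); {32, 32, 1} is the same test value as above.
  const int work_groups_count[3] = {b * c, h, 1};
  const int work_group_size[3] = {32, 32, 1};
  if (!cl_context_ref.command_queue_inst_.DispatchCommand(
        kernel, work_groups_count, work_group_size))
    return;

  // Device -> host copy of the normalized result.
  resultbuf.ReadData(cl_context_ref.command_queue_inst_, result.getData());
}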

22 changes: 13 additions & 9 deletions nntrainer/layers/cl_layers/rmsnorm_layer_cl.h
@@ -19,6 +19,7 @@
#include <layer_impl.h>
#include <nntrainer_log.h>

#include <cl_context.h>
#include <opencl_buffer.h>
#include <opencl_kernel.h>

@@ -49,7 +50,12 @@ class RMS_NORM_GAMMA_INIT_GPU final
* @class RMSNormLayer
* @brief RMS Norm layer
*/

class RMSNormLayerCl : public LayerImpl {

private:
inline static ClContext cl_context_ref;

public:
/**
* @brief Constructor of RMS Norm Layer
@@ -84,9 +90,9 @@ class RMSNormLayerCl : public LayerImpl {
void forwarding(RunLayerContext &context, bool training) override;

/**
* @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
* int from, unsigned int to, bool training)
*/
* @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
* int from, unsigned int to, bool training)
*/
void incremental_forwarding(RunLayerContext &context, unsigned int from,
unsigned int to, bool training) override;

@@ -121,24 +127,22 @@ class RMSNormLayerCl : public LayerImpl {
* @param[in] result Tensor
* @param[in] gamma Tensor
* @param[in] epsilon float
* @param[in] RunLayerContext reference
*/

void rmsnormProcess(Tensor const &input, Tensor &result, Tensor const &gamma,
const float epsilon, RunLayerContext &context);

const float epsilon);
#ifdef ENABLE_FP16
/**
* @brief Process data and dimensions for FP16 rms norm operation
* @param[in] input Tensor
* @param[in] result Tensor
* @param[in] gamma Tensor
* @param[in] epsilon float
* @param[in] RunLayerContext reference
*/

void rmsnormProcess_fp16(Tensor const &input, Tensor &result,
Tensor const &gamma, const float epsilon,
RunLayerContext &context);
Tensor const &gamma, const float epsilon);
#endif
/**
* @copydoc Layer::supportBackwarding()
*/
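For reference, the header changes amount to roughly the condensed shape below. This is a sketch, not the complete class: access specifiers outside the shown hunks, the remaining members, and the doc comments are omitted or assumed.

// Hypothetical condensed view of rmsnorm_layer_cl.h after this commit.
#include <layer_impl.h>
#include <nntrainer_log.h>

#include <cl_context.h> // new include for the shared ClContext
#include <opencl_buffer.h>
#include <opencl_kernel.h>

class RMSNormLayerCl : public LayerImpl {
private:
  // Shared OpenCL context used for kernel registration and dispatch,
  // replacing the RunLayerContext parameter previously threaded through.
  inline static ClContext cl_context_ref;

public:
  void forwarding(RunLayerContext &context, bool training) override;
  void incremental_forwarding(RunLayerContext &context, unsigned int from,
                              unsigned int to, bool training) override;

  // GPU helpers no longer take a RunLayerContext reference.
  void rmsnormProcess(Tensor const &input, Tensor &result, Tensor const &gamma,
                      const float epsilon);
#ifdef ENABLE_FP16
  void rmsnormProcess_fp16(Tensor const &input, Tensor &result,
                           Tensor const &gamma, const float epsilon);
#endif
};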
