[bugfix/refactor] OpenCL buffer creation fix and optimization
Used the proper sizes when creating OpenCL buffers.
Optimized the SGEMM kernel with a 2D global work size.
Updated the function docs.

Signed-off-by: Debadri Samaddar <[email protected]>
s-debadri authored and jijoongmoon committed May 23, 2024
1 parent 6cfcb36 commit ddf8104
Showing 2 changed files with 30 additions and 43 deletions.
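
A note on the buffer-size fix below: the old code multiplied two byte counts, each already scaled by sizeof(float), so every allocation came out sizeof(float) times (4x) larger than needed. A minimal standalone sketch of the arithmetic (illustrative values, not nntrainer code):

    #include <cstdio>

    int main() {
      unsigned int M = 64, K = 32;

      // Old pattern: both factors are byte counts, so sizeof(float) is
      // applied twice and the product is in units of bytes-squared.
      size_t m_size = sizeof(float) * M;
      size_t k_size = sizeof(float) * K;
      size_t buggy = m_size * k_size;        // 256 * 128 = 32768 "bytes"

      // Fixed pattern: element count times element size, applied once.
      size_t fixed = M * K * sizeof(float);  // 64 * 32 * 4 = 8192 bytes

      printf("buggy=%zu fixed=%zu ratio=%zu\n", buggy, fixed, buggy / fixed);
      return 0;
    }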
65 changes: 26 additions & 39 deletions nntrainer/layers/cl_layers/blas_kernels.cpp
@@ -37,19 +37,18 @@ std::string dot_cl_kernel_ =
 
 std::string sgemm_cl_kernel_ =
   R"(__kernel void sgemm_cl(const __global float* A, const __global float* B,
-                      __global float* C, unsigned int M, unsigned int N, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
+                      __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
         unsigned int m = get_global_id(0);
-        for (unsigned int n = 0; n < N; ++n) {
-          float c = 0.0f;
-          for (unsigned int k = 0; k < K; ++k) {
-            float a, b;
-            a = A[m * lda + k];
-            b = B[k * ldb + n];
-            c += a * b;
-          }
-          C[m * ldc + n] = c;
+        unsigned int n = get_global_id(1);
+        float c = 0.0f;
+        for (unsigned int k = 0; k < K; ++k) {
+          float a, b;
+          a = A[m * lda + k];
+          b = B[k * ldb + n];
+          c += a * b;
         }
+        C[m * ldc + n] = c;
       })";
 
 /**
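
In the rewritten kernel above, each work item computes a single output element: get_global_id(0) selects the row m and get_global_id(1) the column n, so M and N disappear from the argument list and are instead encoded in the 2D global work size (see the {M, N, 1} dispatch later in this file). A serial C++ reference of the same index mapping (an illustrative sketch, not part of the commit):

    // The two outer loops correspond to the 2D work-item grid:
    // m = get_global_id(0), n = get_global_id(1).
    void sgemm_ref(const float *A, const float *B, float *C, unsigned int M,
                   unsigned int N, unsigned int K, unsigned int lda,
                   unsigned int ldb, unsigned int ldc) {
      for (unsigned int m = 0; m < M; ++m) {   // global dimension 0
        for (unsigned int n = 0; n < N; ++n) { // global dimension 1
          float c = 0.0f;
          for (unsigned int k = 0; k < K; ++k)
            c += A[m * lda + k] * B[k * ldb + n];
          C[m * ldc + n] = c;
        }
      }
    }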
@@ -74,8 +73,8 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
 
     size_t dim1_size = sizeof(float) * dim1;
     size_t dim2_size = sizeof(float) * dim2;
-    opencl::Buffer inputA(context.context_inst_, dim1_size * dim2_size, true,
-                          nullptr);
+    opencl::Buffer inputA(context.context_inst_, dim1 * dim2 * sizeof(float),
+                          true, nullptr);
 
     opencl::Buffer inputX(context.context_inst_, dim1_size, true, nullptr);
 
@@ -121,7 +120,7 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
       break;
     }
 
-    const int work_groups_count[3] = {(int)dim1, 1, 1};
+    const int work_groups_count[3] = {(int)dim2, 1, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
     result = context.command_queue_inst_.DispatchCommand(
@@ -138,7 +137,7 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
   } while (false);
 }
 
-float dot_cl(const float *matAdata, const float *vecXdata, unsigned int dim1,
+float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
              RunLayerContext &context) {
 
   bool result = false;
@@ -161,7 +160,7 @@ float dot_cl(const float *matAdata, const float *vecXdata, unsigned int dim1,
     opencl::Buffer dotResult(context.context_inst_, sizeof(float), true,
                              &cl_ret);
 
-    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    result = inputA.WriteData(context.command_queue_inst_, vecAdata);
     if (!result) {
       break;
     }
Expand Down Expand Up @@ -223,17 +222,15 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
break;
}

size_t m_size = sizeof(float) * M;
size_t n_size = sizeof(float) * N;
size_t k_size = sizeof(float) * K;
opencl::Buffer inputA(context.context_inst_, m_size * k_size, true,
nullptr);
size_t m_k_size = M * K * sizeof(float);
size_t k_n_size = K * N * sizeof(float);
size_t m_n_size = M * N * sizeof(float);

opencl::Buffer inputA(context.context_inst_, m_k_size, true, nullptr);

opencl::Buffer inputB(context.context_inst_, k_size * n_size, true,
nullptr);
opencl::Buffer inputB(context.context_inst_, k_n_size, true, nullptr);

opencl::Buffer inOutC(context.context_inst_, m_size * n_size, true,
nullptr);
opencl::Buffer inOutC(context.context_inst_, m_n_size, true, nullptr);

result = inputA.WriteData(context.command_queue_inst_, A);
if (!result) {
@@ -265,37 +262,27 @@
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(3, &M, sizeof(int));
-    if (!result) {
-      break;
-    }
-
-    result = kernel_sgemm.SetKernelArguments(4, &N, sizeof(int));
-    if (!result) {
-      break;
-    }
-
-    result = kernel_sgemm.SetKernelArguments(5, &K, sizeof(int));
+    result = kernel_sgemm.SetKernelArguments(3, &K, sizeof(int));
     if (!result) {
      break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(6, &lda, sizeof(int));
+    result = kernel_sgemm.SetKernelArguments(4, &lda, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(7, &ldb, sizeof(int));
+    result = kernel_sgemm.SetKernelArguments(5, &ldb, sizeof(int));
     if (!result) {
       break;
     }
 
-    result = kernel_sgemm.SetKernelArguments(8, &ldc, sizeof(int));
+    result = kernel_sgemm.SetKernelArguments(6, &ldc, sizeof(int));
     if (!result) {
       break;
     }
 
-    const int work_groups_count[3] = {(int)M, 1, 1};
+    const int work_groups_count[3] = {(int)M, (int)N, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
 
     result = context.command_queue_inst_.DispatchCommand(
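The new dispatch passes a 2D global size so the device enumerates all M x N output elements. Assuming DispatchCommand forwards work_groups_count and work_group_size to clEnqueueNDRangeKernel as the global and local sizes (an assumption about the wrapper, not shown in this diff), the equivalent raw OpenCL call would be roughly:

    #include <CL/cl.h>

    // Sketch: launch one work item per element of the M x N output matrix.
    // queue and kernel are assumed to be created and fully configured.
    cl_int dispatch_sgemm(cl_command_queue queue, cl_kernel kernel,
                          unsigned int M, unsigned int N) {
      size_t global[2] = {M, N};   // from work_groups_count {M, N, 1}
      size_t local[2] = {32, 32};  // from work_group_size, the "test-value"
      return clEnqueueNDRangeKernel(queue, kernel, 2, nullptr, global, local,
                                    0, nullptr, nullptr);
    }

One caveat on the fixed 32 x 32 local size: on pre-2.0 OpenCL drivers the global size must be a multiple of the local size in each dimension, so values of M and N that are not multiples of 32 would fail with CL_INVALID_WORK_GROUP_SIZE unless the global size is rounded up and the kernel bounds-checked.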
8 changes: 4 additions & 4 deletions nntrainer/layers/cl_layers/blas_kernels.h
@@ -33,8 +33,8 @@ extern opencl::Kernel kernel_dot;
  * @param[in] matAdata float * for Matrix A
  * @param[in] vecXdata float * for Vector X
  * @param[in] vecYdata float * for Vector Y
- * @param[in] dim1 number of A's row
- * @param[in] dim2 number of X's columns
+ * @param[in] dim1 number of A's columns
+ * @param[in] dim2 number of A's rows
  * @param[in] lda number of X's columns
  * @param[in] context RunLayerContext reference
  */
@@ -44,12 +44,12 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
 
 /**
  * @brief dot computation : sum of all X * Y
- * @param[in] matAdata float * for Vector A
+ * @param[in] vecAdata float * for Vector A
  * @param[in] vecXdata float * for Vector X
  * @param[in] dim1 number of elements in both input vectors
  * @param[in] context RunLayerContext reference
  */
-float dot_cl(const float *matAdata, const float *vecXdata, unsigned int dim1,
+float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
              RunLayerContext &context);
 
 /**
