Make sure when dequantizing weights, take into account cuBLAS' workspace memory.
liuliu committed Oct 16, 2024
1 parent 1d2beb4 commit bed628c
Showing 1 changed file with 10 additions and 8 deletions.
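The change below carves one GPU scratch allocation into three non-overlapping regions: the cuBLAS workspace first, then the depalettized (dequantized) activations, then the depalettized weights. Before this commit, a_data and w_data started at offset 0 of the buffer that is presumably also handed to cuBLAS as its workspace, so a GEMM could clobber the freshly dequantized tensors. A minimal sketch of the layout the fix establishes, with illustrative struct and function names that are not part of ccv's API:

/* One contiguous GPU allocation, shared by cuBLAS and the depalettized
 * tensors. The names here are hypothetical; only the offsets mirror the
 * commit. */
#include <stddef.h>

typedef struct {
	unsigned char* cublas_scratch; /* [0, cublas_size): reserved for cuBLAS */
	unsigned char* a_data;         /* next a_data_size bytes: activations */
	unsigned char* w_data;         /* next w_data_size bytes: weights */
} gemm_workspace_layout_t;

static gemm_workspace_layout_t layout_workspace(void* workspace, size_t cublas_size, size_t a_data_size)
{
	unsigned char* const base = (unsigned char*)workspace;
	gemm_workspace_layout_t layout;
	layout.cublas_scratch = base;
	layout.a_data = base + cublas_size;               /* skip past the cuBLAS region */
	layout.w_data = base + cublas_size + a_data_size; /* then past the activations */
	return layout;
}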
18 changes: 10 additions & 8 deletions lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu
@@ -308,17 +308,18 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		depalettize_w_params.reserved = 0;
 		w_data_size = ccv_nnc_tensor_data_size(depalettize_w_params);
 	}
+	const size_t cublas_size = ccv_nnc_cublas_workspace_size_in_bytes(inputs, input_size, outputs, output_size);
 	void* workspace = 0;
 	if (a_data_size + w_data_size > 0)
-		workspace = ccv_nnc_stream_context_get_workspace(stream_context, a_data_size + w_data_size, CCV_TENSOR_GPU_MEMORY);
+		workspace = ccv_nnc_stream_context_get_workspace(stream_context, cublas_size + a_data_size + w_data_size, CCV_TENSOR_GPU_MEMORY);
 	unsigned char* a_data = a->data.u8;
 	if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
 	{
 		ccv_nnc_tensor_param_t a_params = a->info;
 		const size_t count = ccv_nnc_tensor_count(a_params);
 		const int qbits = (a_params.datatype & 0xf00) >> 8;
 		const int number_in_blocks = a_params.reserved;
-		a_data = (unsigned char*)workspace;
+		a_data = (unsigned char*)workspace + cublas_size;
 		ccv_nnc_compat_depalettize(a->data.u8, a_datatype, ccv_nnc_tensor_data_size_without_padding(a_params), qbits, number_in_blocks, a_data, count, stream_context);
 	}
 	unsigned char* w_data = w->data.u8;
@@ -328,7 +329,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		const size_t count = ccv_nnc_tensor_count(w_params);
 		const int qbits = (w_params.datatype & 0xf00) >> 8;
 		const int number_in_blocks = w_params.reserved;
-		w_data = (unsigned char*)workspace + a_data_size;
+		w_data = (unsigned char*)workspace + cublas_size + a_data_size;
 		ccv_nnc_compat_depalettize(w->data.u8, w_datatype, ccv_nnc_tensor_data_size_without_padding(w_params), qbits, number_in_blocks, w_data, count, stream_context);
 	}
 	// Check if we can shortcut this and use dequantize_mul_mat_vec which will be faster for gmmv.
@@ -363,7 +364,7 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		return CCV_NNC_EXEC_SUCCESS;
 	}
 	cublasHandle_t cublas = ccv_nnc_stream_context_get_cublas(stream_context);
-	ccv_nnc_stream_context_set_cublas_workspace(cublas, stream_context, ccv_nnc_cublas_workspace_size_in_bytes(inputs, input_size, outputs, output_size));
+	ccv_nnc_stream_context_set_cublas_workspace(cublas, stream_context, cublas_size);
 	if (bias)
 	{
 		int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
@@ -622,7 +623,8 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	ccv_nnc_tensor_view_t* bias = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;
 	assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 2-d or 3-d array.
 	cublasHandle_t cublas = ccv_nnc_stream_context_get_cublas(stream_context);
-	ccv_nnc_stream_context_set_cublas_workspace(cublas, stream_context, ccv_nnc_cublas_workspace_size_in_bytes(inputs, input_size, outputs, output_size));
+	const size_t cublas_size = ccv_nnc_cublas_workspace_size_in_bytes(inputs, input_size, outputs, output_size);
+	ccv_nnc_stream_context_set_cublas_workspace(cublas, stream_context, cublas_size);
 	int g_batch_size, g_rows, g_cols, g_batch_inc, g_rows_inc, g_cols_inc;
 	const static int no_transpose[2] = {};
 	ccv_nnc_tensor_get_matrix_params(g->info, CCV_IS_TENSOR_VIEW(g) ? g->stride : 0, g->info.dim, no_transpose, &g_batch_size, &g_rows, &g_cols, &g_batch_inc, &g_rows_inc, &g_cols_inc);
@@ -678,7 +680,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	}
 	void* workspace = 0;
 	if (a_data_size + w_data_size > 0)
-		workspace = ccv_nnc_stream_context_get_workspace(stream_context, a_data_size + w_data_size, CCV_TENSOR_GPU_MEMORY);
+		workspace = ccv_nnc_stream_context_get_workspace(stream_context, cublas_size + a_data_size + w_data_size, CCV_TENSOR_GPU_MEMORY);
 	if (dw)
 	{
 		const ccv_nnc_tensor_view_t* a = (const ccv_nnc_tensor_view_t*)inputs[1];
@@ -689,7 +691,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			const size_t count = ccv_nnc_tensor_count(a_params);
 			const int qbits = (a_params.datatype & 0xf00) >> 8;
 			const int number_in_blocks = a_params.reserved;
-			a_data = (unsigned char*)workspace;
+			a_data = (unsigned char*)workspace + cublas_size;
 			ccv_nnc_compat_depalettize(a->data.u8, a_datatype, ccv_nnc_tensor_data_size_without_padding(a_params), qbits, number_in_blocks, a_data, count, stream_context);
 		}
 		const int transpose_a = ccv_nnc_is_matrix_transpose(a->info, cmd.info.blas.transpose_a);
@@ -744,7 +746,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			const size_t count = ccv_nnc_tensor_count(w_params);
 			const int qbits = (w_params.datatype & 0xf00) >> 8;
 			const int number_in_blocks = w_params.reserved;
-			w_data = (unsigned char*)workspace + a_data_size;
+			w_data = (unsigned char*)workspace + cublas_size + a_data_size;
 			ccv_nnc_compat_depalettize(w->data.u8, w_datatype, ccv_nnc_tensor_data_size_without_padding(w_params), qbits, number_in_blocks, w_data, count, stream_context);
 		}
 		const int transpose_w = ccv_nnc_is_matrix_transpose(w->info, cmd.info.blas.transpose_b);
