[refactor] Moved blas_kernels to tensor directory
Moved common OpenCL blas kernels to the tensor directory. Added pre-processing functions as common functions that can be reused. Signed-off-by: Debadri Samaddar <[email protected]>
Showing 10 changed files with 310 additions and 182 deletions.
@@ -1,12 +1,7 @@
cl_layer_sources = [
  'fc_layer_cl.cpp',
  'blas_kernels.cpp',
]

if get_option('enable-fp16')
  cl_layer_sources += 'blas_kernels_fp16.cpp'
endif

foreach s : cl_layer_sources
  nntrainer_sources += meson.current_source_dir() / s
endforeach
nntrainer/tensor/cl_operations/blas_kernel_interface.cpp (189 additions, 0 deletions)
@@ -0,0 +1,189 @@
// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2024 Debadri Samaddar <[email protected]>
 *
 * @file blas_kernel_interface.cpp
 * @date 5 June 2024
 * @brief Interface for blas OpenCL kernels
 * @see https://github.com/nnstreamer/nntrainer
 * @author Debadri Samaddar <[email protected]>
 * @bug No known bugs except for NYI items
 *
 */

#include <blas_kernel_interface.h>
#include <blas_kernels.h>

namespace nntrainer {
void dotBatchedCl(Tensor const &input, Tensor const &m, Tensor &result,
                  RunLayerContext &context, bool trans, bool trans_m) {
  if (!result.isAllocated())
    throw std::invalid_argument(
      "Output tensor must be preallocated for dotBatched operation");
  for (unsigned int b = 0; b < input.batch(); b++) {
    /** @todo try using transpose to speedup the operation */
    const Tensor this_b = input.getBatchSlice(b, 1);
    Tensor m_b = m.getBatchSlice(b, 1);
    Tensor result_b = result.getBatchSlice(b, 1);

    dotCl(this_b, m_b, result_b, context, trans, trans_m);
  }
}

void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
           RunLayerContext &context, bool trans, bool trans_m) {
  unsigned int dim1, dim2, mdim1, mdim2;
  if (input.getFormat() == Tformat::NHWC) {
    dim1 = input.batch() * input.height() * input.width();
    dim2 = input.channel();
    mdim1 = m.batch() * m.height() * m.width();
    mdim2 = m.channel();
  } else {
    dim1 = input.batch() * input.channel() * input.height();
    dim2 = input.width();
    mdim1 = m.batch() * m.channel() * m.height();
    mdim2 = m.width();
  }

  unsigned int M, N, K, lda, ldb, ldc;

  if (!trans && !trans_m) {
    if (dim2 != mdim1)
      throw std::runtime_error(
        "Error: incompatible dimensions for dot product");
    K = mdim1; /** == dim2 */
    N = mdim2;
    M = dim1;
    if (input.getFormat() == Tformat::NHWC) {
      CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(),
                           input.width(),
                           input.getTensorType()); // NHWC Result Tensor
    } else {
      CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(),
                           input.height(), N, input.getTensorType());
    }
  } else if (!trans && trans_m) {
    if (dim2 != mdim2)
      throw std::runtime_error(
        "Error: incompatible dimensions for dot product");
    K = mdim2; /** == dim2 */
    N = mdim1;
    M = dim1;
    if (input.getFormat() == Tformat::NHWC) {
      CREATE_IF_EMPTY_DIMS(result, input.batch(), N, input.height(),
                           input.width(), input.getTensorType());
    } else {
      CREATE_IF_EMPTY_DIMS(result, input.batch(), input.channel(),
                           input.height(), N, input.getTensorType());
    }
  } else if (trans && !trans_m) {
    if (dim1 != mdim1)
      throw std::runtime_error(
        "Error: incompatible dimensions for dot product");
    K = mdim1; /** == dim1 */
    N = mdim2;
    M = dim2;
    if (input.getFormat() == Tformat::NHWC) {
      CREATE_IF_EMPTY_DIMS(result, 1, N, M, 1, input.getTensorType());
    } else {
      CREATE_IF_EMPTY_DIMS(result, 1, 1, M, N, input.getTensorType());
    }
  } else {
    if (dim1 != mdim2)
      throw std::runtime_error(
        "Error: incompatible dimensions for dot product");
    K = mdim2; /** == dim1 */
    N = mdim1;
    M = dim2;
    if (input.getFormat() == Tformat::NHWC) {
      CREATE_IF_EMPTY_DIMS(result, 1, N, M, 1, input.getTensorType());
    } else {
      CREATE_IF_EMPTY_DIMS(result, 1, 1, M, N, input.getTensorType());
    }
  }

  lda = dim2;
  ldb = mdim2;
  ldc =
    (input.getFormat() == Tformat::NHWC) ? result.channel() : result.width();

  if (input.getDataType() == ml::train::TensorDim::DataType::FP32) {
    const float *data = input.getData();
    const float *mdata = m.getData();
    float *rdata = result.getData();
    enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
    enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;

    /// shortcut handling in case of vector
    /// for vector, (1 * K) == (K * 1) in current memory layout...
    /// and please note that N, K, M are fixed placeholders after considering
    /// transpose.
    /// For example, there is no case like (1 * K) X (1 * K) while
    /// (1 * K) X (1 * M) can be a case
    /// case1: (1 * K) X (K * 1)
    if (M == 1 && N == 1) {
      *rdata = dot_cl(data, mdata, K, context) + (*rdata);
    }
    /// case2: (M * K) X (K * 1)
    else if (N == 1) {
      transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
             : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
    }
    /// case3: (1 * K) X (K * N) = 1 * N = R
    /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
    /// Effectively a translation of sgemv
    else if (M == 1) {
      transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
      transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
             : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
    }
    /// case others: use gemm
    else {
      // transA == false, transB == false
      sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context);
      // todo: other condition implementations
    }
  } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    const _FP16 *data = input.getData<_FP16>();
    const _FP16 *mdata = m.getData<_FP16>();
    _FP16 *rdata = result.getData<_FP16>();
    enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
    enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;

    /// shortcut handling in case of vector
    /// for vector, (1 * K) == (K * 1) in current memory layout...
    /// and please note that N, K, M are fixed placeholders after considering
    /// transpose.
    /// For example, there is no case like (1 * K) X (1 * K) while
    /// (1 * K) X (1 * M) can be a case
    /// case1: (1 * K) X (K * 1)
    if (M == 1 && N == 1) {
      *rdata = dot_cl(data, mdata, K, context) + (*rdata);
    }
    /// case2: (M * K) X (K * 1)
    else if (N == 1) {
      transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
             : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
    }
    /// case3: (1 * K) X (K * N) = 1 * N = R
    /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
    /// Effectively a translation of sgemv
    else if (M == 1) {
      transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
      transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
             : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
    }
    /// case others: use sgemm
    else {
      // transA == false, transB == false
      sgemm_cl(data, mdata, rdata, M, N, K, lda, ldb, ldc, context);
      // todo: other condition implementations
    }
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}

} // namespace nntrainer
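As a reading aid (not part of the commit), the dispatch in dotCl above can be summarized with a few invented NCHW shapes for the no-transpose case; the routing below simply restates the branches of the function.

```cpp
// Illustrative shapes only (trans == false, trans_m == false, NCHW layout);
// the routing restates the branches of dotCl above.
//
// input (1,1,1,K) x m (1,1,K,1) -> M == 1 && N == 1 -> dot_cl   (dot product)
// input (1,1,M,K) x m (1,1,K,1) -> N == 1           -> sgemv_cl
// input (1,1,1,K) x m (1,1,K,N) -> M == 1           -> sgemv_cl, transpose flag flipped
// input (1,1,M,K) x m (1,1,K,N) -> otherwise        -> sgemm_cl (M x K times K x N)
//
// Example: M = 4, K = 8, N = 16 gives lda = 8, ldb = 16, ldc = 16, and the
// result is created as a (1,1,4,16) tensor via CREATE_IF_EMPTY_DIMS.
```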