From b801d400d32199a359f91c0d580bdcbd82e1abb5 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 3 Sep 2024 18:55:50 +0200 Subject: [PATCH] [WIP][DRAFT][onert] Optimized BatchMatMul in CPU backend This commit introduces an improved BMM kernel for CPU. ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz --- .../cker/operation/reference/BatchMatMul.h | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/compute/cker/include/cker/operation/reference/BatchMatMul.h b/compute/cker/include/cker/operation/reference/BatchMatMul.h index 1b3020de22c..1e943422acf 100644 --- a/compute/cker/include/cker/operation/reference/BatchMatMul.h +++ b/compute/cker/include/cker/operation/reference/BatchMatMul.h @@ -20,6 +20,7 @@ #include "cker/Types.h" #include "cker/Shape.h" +#include "cker/operation/optimized/Gemm.h" namespace nnfw { @@ -73,7 +74,7 @@ inline void BatchMatMul(const Shape &lhs_shape, const float *lhs_data, const Sha // Set params for each matrix multiply. const int lhs_rows = extended_lhs_shape.Dims(3); const int rhs_cols = extended_rhs_shape.Dims(4); - const int accum_depth = extended_lhs_shape.Dims(4); + // const int accum_depth = extended_lhs_shape.Dims(4); for (int b0 = 0; b0 < batch_dim0; ++b0) { @@ -89,19 +90,45 @@ inline void BatchMatMul(const Shape &lhs_shape, const float *lhs_data, const Sha const float *rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; float *out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) * lhs_rows * rhs_cols; - for (int j = 0; j < rhs_cols; ++j) - { - for (int i = 0; i < lhs_rows; ++i) - { - float total = 0.f; - for (int k = 0; k < accum_depth; ++k) - { - total += lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k]; - } - int idx = lhs_rows * j + i; - out_ptr[idx] = total; - } - } + + MatrixParams rhs_params; + rhs_params.order = Order::kColMajor; // Should it always be like this? Base it on adj_x & adj_y? 
+ rhs_params.rows = rhs_cols; + rhs_params.cols = lhs_rows; + // How to determine this? + rhs_params.cache_policy = nnfw::cker::optimized::DefaultCachePolicy(false); + + MatrixParams lhs_params; + lhs_params.order = Order::kRowMajor; // Should it always be like this? Base it on adj_x & adj_y? + lhs_params.rows = lhs_rows; + lhs_params.cols = rhs_cols; + // How to determine this? + lhs_params.cache_policy = nnfw::cker::optimized::DefaultCachePolicy(false); + + MatrixParams dst_params; + dst_params.order = Order::kColMajor; + dst_params.rows = lhs_rows; + dst_params.cols = rhs_cols; + + GemmParams gemm_params; + + nnfw::cker::optimized::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2, dst_params, out_ptr, + gemm_params); + + // for (int j = 0; j < rhs_cols; ++j) + // { + // for (int i = 0; i < lhs_rows; ++i) + // { + // float total = 0.f; + // for (int k = 0; k < accum_depth; ++k) + // { + // total += lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k]; + // } + // int idx = lhs_rows * j + i; + // out_ptr[idx] = total; + // } + // } + } } }