diff --git a/compute/cker/include/cker/operation/BatchMatMul.h b/compute/cker/include/cker/operation/BatchMatMul.h
index 18070982ad2..ca2ddb8fa8e 100644
--- a/compute/cker/include/cker/operation/BatchMatMul.h
+++ b/compute/cker/include/cker/operation/BatchMatMul.h
@@ -23,6 +23,7 @@
 #include "cker/Types.h"
 #include "cker/Shape.h"
 #include "cker/Utils.h"
+#include "cker/operation/optimized/BatchMatMul.h"
 #include "cker/operation/reference/BatchMatMul.h"
 
 #include <vector>
@@ -77,7 +78,7 @@ class BatchMatMul
   }
 
   void operator()(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape,
-                  const float *rhs_data, bool adj_x, bool adj_y, const Shape &output_shape,
+                  const float *rhs_data, bool adj_x, bool adj_y, const Shape & /*output_shape*/,
                   float *output_data)
   {
     // Assume lhs and rhs is not constant
@@ -102,8 +103,13 @@ class BatchMatMul
     // Check accumulative dimensions of lhs and rhs of are equal
     assert(Shape::ExtendedShape(5, new_rhs_shape).Dims(4) ==
            Shape::ExtendedShape(5, new_lhs_shape).Dims(3));
-    reference::BatchMatMul(new_rhs_shape, new_rhs_data, new_lhs_shape, new_lhs_data, output_shape,
-                           output_data);
+
+    const BatchMatMulParams params{new_rhs_shape, new_lhs_shape};
+#if defined(CKER_X86_PLATFORM)
+    optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
+#else
+    reference::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
+#endif
   }
 
 private:
diff --git a/compute/cker/include/cker/operation/Helper/BatchMatMulParams.h b/compute/cker/include/cker/operation/Helper/BatchMatMulParams.h
new file mode 100644
index 00000000000..3949865bac6
--- /dev/null
+++ b/compute/cker/include/cker/operation/Helper/BatchMatMulParams.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_HELPER_BATCH_MAT_MUL_PARAMS_H__
+#define __NNFW_CKER_HELPER_BATCH_MAT_MUL_PARAMS_H__
+
+#include "cker/Shape.h"
+
+namespace nnfw
+{
+namespace cker
+{
+struct BatchMatMulParams
+{
+  BatchMatMulParams(const Shape &lhs_shape, const Shape &rhs_shape)
+  {
+    const Shape extended_lhs_shape = Shape::ExtendedShape(5, lhs_shape);
+    const Shape extended_rhs_shape = Shape::ExtendedShape(5, rhs_shape);
+
+    batch_dim0 = broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+    batch_dim1 = broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+    batch_dim2 = broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+    lhs_ext0 = extent(extended_lhs_shape, 0);
+    lhs_ext1 = extent(extended_lhs_shape, 1);
+    lhs_ext2 = extent(extended_lhs_shape, 2);
+    rhs_ext0 = extent(extended_rhs_shape, 0);
+    rhs_ext1 = extent(extended_rhs_shape, 1);
+    rhs_ext2 = extent(extended_rhs_shape, 2);
+
+    // Set params for each matrix multiply.
+    lhs_rows = extended_lhs_shape.Dims(3);
+    lhs_cols = extended_lhs_shape.Dims(4);
+    rhs_rows = extended_rhs_shape.Dims(3);
+    rhs_cols = extended_rhs_shape.Dims(4);
+    accum_depth = extended_lhs_shape.Dims(4);
+  }
+
+  int batch_dim0;
+  int batch_dim1;
+  int batch_dim2;
+  int lhs_ext0;
+  int lhs_ext1;
+  int lhs_ext2;
+  int rhs_ext0;
+  int rhs_ext1;
+  int rhs_ext2;
+  int lhs_rows;
+  int lhs_cols;
+  int rhs_rows;
+  int rhs_cols;
+  int accum_depth;
+
+private:
+  // Determines which dimension is the broadcast dimension.
+  int32_t broadcast_dim(int32_t lhs_dim, int32_t rhs_dim)
+  {
+    if (lhs_dim == rhs_dim)
+      return lhs_dim;
+    if (lhs_dim == 1)
+      return rhs_dim;
+    assert(rhs_dim == 1);
+    return lhs_dim;
+  }
+
+  // Computes the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  int extent(const Shape &shape, int x)
+  {
+    if (shape.Dims(x) == 1)
+    {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i)
+    {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  }
+};
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_HELPER_BATCH_MAT_MUL_PARAMS_H__
diff --git a/compute/cker/include/cker/operation/optimized/BatchMatMul.h b/compute/cker/include/cker/operation/optimized/BatchMatMul.h
new file mode 100644
index 00000000000..66d17b82821
--- /dev/null
+++ b/compute/cker/include/cker/operation/optimized/BatchMatMul.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
+#define __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
+
+#include "cker/Shape.h"
+#include "cker/operation/Helper/BatchMatMulParams.h"
+#include "Gemm.h"
+
+namespace nnfw
+{
+namespace cker
+{
+namespace optimized
+{
+inline void BatchMatMul(const BatchMatMulParams &params, const float *lhs_data,
+                        const float *rhs_data, float *output_data)
+{
+  MatrixParams<float> lhs_params;
+  lhs_params.order = Order::kRowMajor;
+  lhs_params.rows = params.lhs_rows;
+  lhs_params.cols = params.lhs_cols;
+
+  MatrixParams<float> rhs_params;
+  rhs_params.order = Order::kRowMajor;
+  rhs_params.rows = params.rhs_rows;
+  rhs_params.cols = params.rhs_cols;
+
+  MatrixParams<float> dst_params;
+  dst_params.order = Order::kRowMajor;
+  dst_params.rows = params.lhs_rows;
+  dst_params.cols = params.rhs_cols;
+
+  for (int b0 = 0; b0 < params.batch_dim0; ++b0)
+  {
+    for (int b1 = 0; b1 < params.batch_dim1; ++b1)
+    {
+      for (int b2 = 0; b2 < params.batch_dim2; ++b2)
+      {
+        const float *lhs_ptr =
+          lhs_data + b0 * params.lhs_ext0 + b1 * params.lhs_ext1 + b2 * params.lhs_ext2;
+        const float *rhs_ptr =
+          rhs_data + b0 * params.rhs_ext0 + b1 * params.rhs_ext1 + b2 * params.rhs_ext2;
+        float *out_ptr = output_data + ((b0 * params.batch_dim1 * params.batch_dim2) +
+                                        b1 * params.batch_dim2 + b2) *
+                                         params.lhs_rows * params.rhs_cols;
+
+        Gemm(lhs_params, lhs_ptr, rhs_params, rhs_ptr, dst_params, out_ptr,
+             GemmParams<float, float>{});
+      }
+    }
+  }
+}
+} // namespace optimized
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
diff --git a/compute/cker/include/cker/operation/reference/BatchMatMul.h b/compute/cker/include/cker/operation/reference/BatchMatMul.h
index 1b3020de22c..4060bb9365a 100644
--- a/compute/cker/include/cker/operation/reference/BatchMatMul.h
+++ b/compute/cker/include/cker/operation/reference/BatchMatMul.h
@@ -20,6 +20,7 @@
 
 #include "cker/Types.h"
 #include "cker/Shape.h"
+#include "cker/operation/Helper/BatchMatMulParams.h"
 
 namespace nnfw
 {
@@ -28,77 +29,34 @@ namespace cker
 namespace reference
 {
 
-inline void BatchMatMul(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape,
-                        const float *rhs_data, const Shape &, float *output_data)
+inline void BatchMatMul(const BatchMatMulParams &params, const float *lhs_data,
+                        const float *rhs_data, float *output_data)
 {
-  const Shape extended_lhs_shape = Shape::ExtendedShape(5, lhs_shape);
-  const Shape extended_rhs_shape = Shape::ExtendedShape(5, rhs_shape);
-
-  // Determine which dimension is the broadcast dimension.
-  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
-    if (lhs_dim == rhs_dim)
-      return lhs_dim;
-    if (lhs_dim == 1)
-      return rhs_dim;
-    assert(rhs_dim == 1);
-    return lhs_dim;
-  };
-
-  // Compute the "extent" for iterating on this dimension.
-  // If we are broadcasting, then don't advance (i.e return 0).
- auto extent = [](const Shape &shape, int x) { - if (shape.Dims(x) == 1) - { - return 0; - } - int prod = 1; - for (int i = x + 1; i < shape.DimensionsCount(); ++i) - { - prod *= shape.Dims(i); - } - return prod; - }; - - const int batch_dim0 = broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); - const int batch_dim1 = broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); - const int batch_dim2 = broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); - - const int lhs_ext0 = extent(extended_lhs_shape, 0); - const int lhs_ext1 = extent(extended_lhs_shape, 1); - const int lhs_ext2 = extent(extended_lhs_shape, 2); - const int rhs_ext0 = extent(extended_rhs_shape, 0); - const int rhs_ext1 = extent(extended_rhs_shape, 1); - const int rhs_ext2 = extent(extended_rhs_shape, 2); - - // Set params for each matrix multiply. - const int lhs_rows = extended_lhs_shape.Dims(3); - const int rhs_cols = extended_rhs_shape.Dims(4); - const int accum_depth = extended_lhs_shape.Dims(4); - - for (int b0 = 0; b0 < batch_dim0; ++b0) + for (int b0 = 0; b0 < params.batch_dim0; ++b0) { - const float *lhs_ptr0 = lhs_data + (b0 * lhs_ext0); - const float *rhs_ptr0 = rhs_data + (b0 * rhs_ext0); - for (int b1 = 0; b1 < batch_dim1; ++b1) + const float *lhs_ptr0 = lhs_data + (b0 * params.lhs_ext0); + const float *rhs_ptr0 = rhs_data + (b0 * params.rhs_ext0); + for (int b1 = 0; b1 < params.batch_dim1; ++b1) { - const float *lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; - const float *rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; - for (int b2 = 0; b2 < batch_dim2; ++b2) + const float *lhs_ptr1 = lhs_ptr0 + b1 * params.lhs_ext1; + const float *rhs_ptr1 = rhs_ptr0 + b1 * params.rhs_ext1; + for (int b2 = 0; b2 < params.batch_dim2; ++b2) { - const float *lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; - const float *rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; - float *out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) * - lhs_rows * rhs_cols; - for (int j = 0; j < rhs_cols; ++j) + const float *lhs_ptr2 = lhs_ptr1 + b2 * params.lhs_ext2; + const float *rhs_ptr2 = rhs_ptr1 + b2 * params.rhs_ext2; + float *out_ptr = output_data + ((b0 * params.batch_dim1 * params.batch_dim2) + + b1 * params.batch_dim2 + b2) * + params.lhs_rows * params.rhs_cols; + for (int j = 0; j < params.rhs_cols; ++j) { - for (int i = 0; i < lhs_rows; ++i) + for (int i = 0; i < params.lhs_rows; ++i) { float total = 0.f; - for (int k = 0; k < accum_depth; ++k) + for (int k = 0; k < params.accum_depth; ++k) { - total += lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k]; + total += lhs_ptr2[params.accum_depth * i + k] * rhs_ptr2[j * params.accum_depth + k]; } - int idx = lhs_rows * j + i; + int idx = params.lhs_rows * j + i; out_ptr[idx] = total; } }
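For reference, the broadcast bookkeeping that BatchMatMulParams centralizes can be exercised on its own. The sketch below is a standalone re-implementation of the broadcast_dim/extent rules, for illustration only (plain C++, no cker headers; the dimension values are made up): batch dimensions are compatible iff they are equal or one side is 1, and a broadcast dimension gets extent 0 so the corresponding data pointer does not advance in that batch loop.

    #include <cassert>
    #include <cstdio>

    // Same rule as BatchMatMulParams::broadcast_dim: two batch dims are
    // compatible iff they are equal or one of them is 1.
    static int broadcast_dim(int lhs_dim, int rhs_dim)
    {
      if (lhs_dim == rhs_dim)
        return lhs_dim;
      if (lhs_dim == 1)
        return rhs_dim;
      assert(rhs_dim == 1);
      return lhs_dim;
    }

    // Same rule as BatchMatMulParams::extent: element stride of dimension x,
    // or 0 when that dimension is broadcast.
    static int extent(const int *dims, int rank, int x)
    {
      if (dims[x] == 1)
        return 0;
      int prod = 1;
      for (int i = x + 1; i < rank; ++i)
        prod *= dims[i];
      return prod;
    }

    int main()
    {
      // Extended 5-D shapes: three batch dims plus a trailing matrix.
      const int lhs[5] = {2, 1, 4, 8, 16};  // batch [2,1,4], 8x16 matrices
      const int rhs[5] = {1, 3, 4, 16, 32}; // batch [1,3,4], 16x32 matrices

      // Output batch is the element-wise broadcast: prints "batch: 2 3 4".
      std::printf("batch: %d %d %d\n", broadcast_dim(lhs[0], rhs[0]),
                  broadcast_dim(lhs[1], rhs[1]), broadcast_dim(lhs[2], rhs[2]));

      // Prints "lhs_ext: 512 0 128": lhs is broadcast along batch dim 1, so
      // the b1 loop leaves lhs_ptr in place, which is how both kernels use
      // lhs_ext1.
      std::printf("lhs_ext: %d %d %d\n", extent(lhs, 5, 0), extent(lhs, 5, 1),
                  extent(lhs, 5, 2));
      return 0;
    }

Note that the BatchMatMul functor builds the params with its operands swapped (params{new_rhs_shape, new_lhs_shape}) and passes new_rhs_data as the kernels' lhs, so "lhs" inside both kernels is the caller's rhs; the rules sketched above are the same either way.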