diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp index a41a5ba6dc..faf6d21cbe 100644 --- a/nntrainer/tensor/hgemm/hgemm.cpp +++ b/nntrainer/tensor/hgemm/hgemm.cpp @@ -32,15 +32,17 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M, unsigned int N, unsigned int K, float alpha, float beta) { if (alpha == 1.F && beta == 0.F) { - if (M % 8 == 0 && N % 16 == 0 && K % 8 == 0) { + // used bitwise operator instead of modulo for performance + // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M + if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) { hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta); - } else if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) { + } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) { hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta); - } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) { + } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) { hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta); - } else if (N % 8 == 0) { + } else if ((K & 0x7) == 0 && (N & 0x7) == 0) { hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta); - } else if (N % 4 == 0) { + } else if ((K & 0x7) == 0 && (N & 0x3) == 0) { hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta); } else { hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta); @@ -52,17 +54,19 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M, void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N, unsigned int K, float alpha, float beta) { if (alpha == 1.F && beta == 0.F) { - if (M % 8 == 0 && N % 16 == 0 && K % 8 == 0) { + // used bitwise operator instead of modulo for performance + // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M + if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) { hgemm_noTrans_8x16(M, N, K, A, K, B, N, C, N, alpha, beta); - } else if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) { + } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) { hgemm_noTrans_8x8(M, N, K, A, K, B, N, C, N, alpha, beta); - } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) { + } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) { hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta); - } else if (N % 8 == 0) { - hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta); - } else if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) { + } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) { hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta); - } else if (N % 4 == 0) { + } else if ((K & 0x7) == 0 && (N & 0x7) == 0) { + hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta); + } else if ((K & 0x7) == 0 && (N & 0x3) == 0) { hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta); } } diff --git a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp index e02eac1786..ff55ca543d 100644 --- a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp +++ b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp @@ -701,6 +701,128 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_20000) { EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); } +TEST(nntrainer_Tensor, dot_gemm_512_520_1032) { + /// @note GEMM : A X B = C + int batch = 1; + int channel = 1; + int height = 512; + int width = 520; + + int height_b = 520; + int width_b = 1032; + + bool transA = false; + bool transB = false; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16); + nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16); + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + + const float alpha = 1e-1; + const int MOD = 10; + + GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) + + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + nntrainer::Tensor C = A.dot(B, transA, transB); + + nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + + float mseErrorNeon = + mse<__fp16>(C.getData<__fp16>(), C_fp32.getData(), C.size()); + + double cosSimNeon = cosine_similarity<__fp16>( + C.getData<__fp16>(), C_fp32.getData(), C.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); +} + +TEST(nntrainer_Tensor, dot_gemm_1001_1024_20000) { + /// @note GEMM : A X B = C + int batch = 1; + int channel = 1; + int height = 1001; + int width = 1024; + + int height_b = 1024; + int width_b = 20000; + + bool transA = false; + bool transB = false; + + nntrainer::TensorDim::TensorType t_type_nchw_fp16 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}; + + nntrainer::TensorDim::TensorType t_type_nchw_fp32 = { + nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}; + + nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16); + nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16); + + nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32); + nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32); + + const float alpha = 1e-1; + const int MOD = 10; + + GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) + + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) + + j * (batch * height) + k * (width) + l + 1) % + MOD) * + alpha); + GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) + + j * (batch * height_b) + k * (width_b) + l + 1) % + MOD) * + alpha); + + nntrainer::Tensor C = A.dot(B, transA, transB); + + nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB); + + float mseErrorNeon = + mse<__fp16>(C.getData<__fp16>(), C_fp32.getData(), C.size()); + + double cosSimNeon = cosine_similarity<__fp16>( + C.getData<__fp16>(), C_fp32.getData(), C.size()); + + const float epsilon = 1e-3 * width; + + EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon); + EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1); +} + TEST(nntrainer_Tensor, dot_gemm_50_768_516) { /// @note GEMM : A X B = C int batch = 1;