From ee66e16dfdf9315d952907fae499cf6cdd282e27 Mon Sep 17 00:00:00 2001
From: skykongkong8
Date: Mon, 15 Apr 2024 13:11:24 +0900
Subject: [PATCH] [ Trivial ] Remove redundant comments and format

- Due to adaptive macro kernel usage, the previous accuracy/latency
  comments are no longer needed.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8
---
 nntrainer/tensor/hgemm/hgemm.cpp           |  4 +-
 nntrainer/tensor/hgemm/hgemm_kernel_4x4.h  |  6 +--
 nntrainer/tensor/hgemm/hgemm_kernel_4x8.h  |  8 ++--
 nntrainer/tensor/hgemm/hgemm_kernel_8x16.h | 50 +++++++++++-----------
 nntrainer/tensor/hgemm/hgemm_kernel_8x8.h  |  8 ++--
 5 files changed, 37 insertions(+), 39 deletions(-)

diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp
index 61353e595b..4aaadf331c 100644
--- a/nntrainer/tensor/hgemm/hgemm.cpp
+++ b/nntrainer/tensor/hgemm/hgemm.cpp
@@ -64,9 +64,9 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
     hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
   } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
     hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
-  } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+  } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
     hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-  } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+  } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
     hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
   }
 }
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
index ab49faaca7..7bf75b13b7 100644
--- a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
+++ b/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
@@ -20,7 +20,7 @@
   v26 = vdup_n_f16(0.F); \
   v27 = vdup_n_f16(0.F);
 
-// 1. Partial sum 256 digits : Medium accuracy, medium latency
+// 1. Partial sum 256 digits
 #define KERNEL_4x4_ACC16() \
   dv0 = vld1_f16(a); \
   vb0 = vld1_f16(b); \
@@ -124,7 +124,7 @@
   b += 4 * 16; \
   a += 4 * 16;
 
-// 2. Partial sum 128 digits : Medium accuracy, medium latency
+// 2. Partial sum 128 digits
 #define KERNEL_4x4_ACC8() \
   dv0 = vld1_f16(a); \
   vb0 = vld1_f16(b); \
@@ -180,7 +180,7 @@
   b += 4 * 8; \
   a += 4 * 8;
 
-// 2. Partial sum 16 digits : Best accuracy, worst latency
+// 3. Partial sum 16 digits
 #define KERNEL_4x4_ACC1() \
   dv0 = vld1_f16(a); \
   vb0 = vld1_f16(b); \
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
index 064f0a7b73..01204457e9 100644
--- a/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
+++ b/nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
@@ -20,7 +20,7 @@
   v6 = vdupq_n_f16(0.F); \
   v9 = vdupq_n_f16(0.F);
 
-// 1. Partial sum 256 digits : worst accuracy, best latency
+// 1. Partial sum 512 digits
 #define KERNEL_4x8_ACC16() \
   dv0 = vld1_f16(a); \
   v24 = vld1q_f16(b); \
@@ -124,7 +124,7 @@
   b += 8 * 16; \
   a += 4 * 16;
 
-// 1. Partial sum 256 digits : worst accuracy, best latency
+// 2. Partial sum 256 digits
 #define KERNEL_4x8_ACC8() \
   dv0 = vld1_f16(a); \
   v24 = vld1q_f16(b); \
@@ -180,7 +180,7 @@
   b += 8 * 8; \
   a += 4 * 8;
 
-// 2. Partial sum 128 digits : medium accuracy, medium latency
+// 3. Partial sum 128 digits
 #define KERNEL_4x8_ACC4() \
   dv0 = vld1_f16(a); \
   v24 = vld1q_f16(b); \
@@ -212,7 +212,7 @@
   b += 8 * 4; \
   a += 4 * 4;
 
-// 3. Partial sum 32 digits : Best accuracy, worst latency
+// 4. Partial sum 32 digits
 #define KERNEL_4x8_ACC1() \
   dv0 = vld1_f16(a); \
   v24 = vld1q_f16(b); \
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h b/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
index 38778ea8f3..a89a6b5421 100644
--- a/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
+++ b/nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
@@ -32,7 +32,7 @@
   v112_119 = vdupq_n_f16(0.F); \
   v120_127 = vdupq_n_f16(0.F);
 
-// 0. Partial sum 2048 digits : Best latency, worst accuracy.
+// 1. Partial sum 2048 digits
 #define KERNEL_8x16_ACC16() \
   va0 = vld1q_f16(a); \
   v24 = vld1q_f16(b); \
@@ -344,7 +344,7 @@
   b += 16 * 16; \
   a += 8 * 16;
 
-// 1. Partial sum 1024 digits : Medium-high accuracy, medium latency
+// 2. Partial sum 1024 digits
 #define KERNEL_8x16_ACC8() \
   va0 = vld1q_f16(a); \
   v24 = vld1q_f16(b); \
@@ -504,7 +504,7 @@
   b += 16 * 8; \
   a += 8 * 8;
 
-// 2. Partial sum 512 digits : Medium accuracy, medium latency
+// 3. Partial sum 512 digits
 #define KERNEL_8x16_ACC4() \
   va0 = vld1q_f16(a); \
   v24 = vld1q_f16(b); \
@@ -588,7 +588,7 @@
   b += 16 * 4; \
   a += 8 * 4;
 
-// 3. Partial sum 128 digits : Best accuracy, worst latency
+// 4. Partial sum 128 digits
 #define KERNEL_8x16_ACC1() \
   va0 = vld1q_f16(a); \
   v24 = vld1q_f16(b); \
@@ -740,28 +740,26 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
       for (; l < K;) {
         KERNEL_8x16_ACC1();
       }
-        vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
-        vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
-        vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
-        vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
-        vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
-        vst1q_f16(c + 2 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
-        vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
-        vst1q_f16(c + 3 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
-        vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
-        vst1q_f16(c + 4 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
-        vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
-        vst1q_f16(c + 5 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
-        vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
-        vst1q_f16(c + 6 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
-        vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
-        vst1q_f16(c + 7 * ldc + 8,
-                  vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
+      vst1q_f16(c, vaddq_f16(vld1q_f16(c), v0_7));
+      vst1q_f16(c + 8, vaddq_f16(vld1q_f16(c + 8), v64_71));
+      vst1q_f16(c + ldc, vaddq_f16(vld1q_f16(c + ldc), v8_15));
+      vst1q_f16(c + ldc + 8, vaddq_f16(vld1q_f16(c + ldc + 8), v72_79));
+      vst1q_f16(c + 2 * ldc, vaddq_f16(vld1q_f16(c + 2 * ldc), v16_23));
+      vst1q_f16(c + 2 * ldc + 8, vaddq_f16(vld1q_f16(c + 2 * ldc + 8), v80_87));
+      vst1q_f16(c + 3 * ldc, vaddq_f16(vld1q_f16(c + 3 * ldc), v24_31));
+      vst1q_f16(c + 3 * ldc + 8, vaddq_f16(vld1q_f16(c + 3 * ldc + 8), v88_95));
+      vst1q_f16(c + 4 * ldc, vaddq_f16(vld1q_f16(c + 4 * ldc), v32_39));
+      vst1q_f16(c + 4 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 4 * ldc + 8), v96_103));
+      vst1q_f16(c + 5 * ldc, vaddq_f16(vld1q_f16(c + 5 * ldc), v40_47));
+      vst1q_f16(c + 5 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 5 * ldc + 8), v104_111));
+      vst1q_f16(c + 6 * ldc, vaddq_f16(vld1q_f16(c + 6 * ldc), v48_55));
+      vst1q_f16(c + 6 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 6 * ldc + 8), v112_119));
+      vst1q_f16(c + 7 * ldc, vaddq_f16(vld1q_f16(c + 7 * ldc), v56_63));
+      vst1q_f16(c + 7 * ldc + 8,
+                vaddq_f16(vld1q_f16(c + 7 * ldc + 8), v120_127));
       c += 16;
       a -= 8 * K;
     }
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h b/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h
index c913bdd040..4901c3f518 100644
--- a/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h
+++ b/nntrainer/tensor/hgemm/hgemm_kernel_8x8.h
@@ -24,7 +24,7 @@
   v30 = vdupq_n_f16(0.F); \
   v31 = vdupq_n_f16(0.F);
 
-// 1. Partial sum 1024 digits : Worst accuracy, best latency
+// 1. Partial sum 1024 digits
 #define KERNEL_8x8_ACC16() \
   va0 = vld1q_f16(a); \
   v16 = vld1q_f16(b); \
@@ -192,7 +192,7 @@
   b += 8 * 16; \
   a += 8 * 16;
 
-// 2. Partial sum 512 digits : Medium accuracy, medium latency
+// 2. Partial sum 512 digits
 #define KERNEL_8x8_ACC8() \
   va0 = vld1q_f16(a); \
   v16 = vld1q_f16(b); \
@@ -280,7 +280,7 @@
   b += 8 * 8; \
   a += 8 * 8;
 
-// 3. Partial sum 256 digits : Medium accuracy, medium latency
+// 3. Partial sum 256 digits
 #define KERNEL_8x8_ACC4() \
   va0 = vld1q_f16(a); \
   v16 = vld1q_f16(b); \
@@ -328,7 +328,7 @@
   b += 8 * 4; \
   a += 8 * 4;
 
-// 4. Partial sum 64 digits : Best accuracy, worst latency
+// 4. Partial sum 64 digits
 #define KERNEL_8x8_ACC1() \
   va0 = vld1q_f16(a); \
   v16 = vld1q_f16(b); \
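
For reviewers unfamiliar with the dispatch touched in the hgemm.cpp hunk above,
here is a minimal, self-contained sketch of the bitmask divisibility pattern it
relies on: (N & 0x7) == 0 is equivalent to N % 8 == 0, so the 1x8 path is taken
only when both N and K are multiples of 8, and the 1x4 path when N is a
multiple of 4 and K a multiple of 8. The pick_kernel() helper and main() below
are hypothetical illustrations, not part of nntrainer; only the
hgemm_noTrans_1x8 / hgemm_noTrans_1x4 names come from the patch.

#include <stdio.h>

/* Hypothetical helper: mirrors the divisibility checks reordered in the
 * hgemm.cpp hunk, using bitmasks instead of the modulo operator. */
static const char *pick_kernel(unsigned int N, unsigned int K) {
  if ((N & 0x7) == 0 && (K & 0x7) == 0) /* N % 8 == 0 && K % 8 == 0 */
    return "hgemm_noTrans_1x8";
  if ((N & 0x3) == 0 && (K & 0x7) == 0) /* N % 4 == 0 && K % 8 == 0 */
    return "hgemm_noTrans_1x4";
  return "no specialized kernel";
}

int main(void) {
  printf("%s\n", pick_kernel(24, 16)); /* prints hgemm_noTrans_1x8 */
  printf("%s\n", pick_kernel(12, 16)); /* prints hgemm_noTrans_1x4 */
  printf("%s\n", pick_kernel(10, 16)); /* prints no specialized kernel */
  return 0;
}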