-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ BLAS ] Support non-4-divisible case in matrix transpose
- Previously, there was a defect when transposing a matrix whose column length is not divisible by 4. - Fix the bug and refactor the interface that uses it: move the transpose fallback for the case where NEON is supported. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: skykongkong8 <[email protected]>
- Loading branch information
1 parent
c3b6175
commit 4efa98b
Showing
5 changed files
with
120 additions
and
184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
/** | ||
* Copyright (C) 2024 Sungsik Kong <[email protected]> | ||
* | ||
* @file transpose_utils_neon.h | ||
* @file matrix_transpose_kernels_neon.h | ||
* @date 09 May 2024 | ||
* @see https://github.com/nnstreamer/nntrainer | ||
* @author Sungsik Kong <[email protected]> | ||
|
@@ -14,7 +14,7 @@ | |
#include <arm_neon.h> | ||
#include <cassert> | ||
#include <cstdint> | ||
#include "./mask_neon.h" | ||
#include <mask_neon.h> | ||
|
||
#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3) \ | ||
float16x4x2_t row01 = vtrn_f16(row0, row1); \ | ||
|
@@ -30,7 +30,6 @@ | |
vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])), \ | ||
vget_high_f32(vcvt_f32_f16(row23.val[1])))); | ||
|
||
|
||
static inline void transpose_kernel_4x4_neon(const __fp16 *src, | ||
unsigned int ld_src, __fp16 *dst, | ||
unsigned int ld_dst) { | ||
|
@@ -52,8 +51,8 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src, | |
unsigned int ld_src, __fp16 *dst, | ||
unsigned int ld_dst) { | ||
|
||
|
||
uint16x4_t bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[N])); | ||
uint16x4_t bitmask_v8 = | ||
vld1_u16(reinterpret_cast<const uint16_t *>(masks[N])); | ||
float16x4_t input[4]; | ||
float16x4_t ZEROS = vmov_n_f16(0.F); | ||
|
||
|
@@ -81,13 +80,13 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src, | |
vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(temp[i / 2])), | ||
vget_low_f32(vcvt_f32_f16(temp[2 + i / 2])))); | ||
} else { | ||
input[i] = | ||
vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])), | ||
vget_high_f32(vcvt_f32_f16(temp[2 + i / 2])))); | ||
input[i] = vcvt_f16_f32( | ||
vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])), | ||
vget_high_f32(vcvt_f32_f16(temp[2 + i / 2])))); | ||
} | ||
vst1_f16(&dst[i * ld_dst], vbsl_f16(bitmask_v8, input[i], ZEROS)); | ||
vst1_f16(&dst[i * ld_dst], | ||
vbsl_f16(bitmask_v8, input[i], vld1_f16(&dst[i * ld_dst]))); | ||
} | ||
|
||
} | ||
|
||
static inline void transpose_kernel_8x8_neon(const __fp16 *src, | ||
|
@@ -126,7 +125,7 @@ static inline void transpose_kernel_8x8_neon(const __fp16 *src, | |
vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks)); | ||
abcd04 = vbslq_f16(shuffle_mask, ab0145, vextq_f16(cd0145, cd0145, 6)); | ||
abcd15 = vbslq_f16(shuffle_mask, vextq_f16(ab0145, ab0145, 2), cd0145); | ||
|
||
efgh04 = vbslq_f16(shuffle_mask, ef0145, vextq_f16(gh0145, gh0145, 6)); | ||
efgh15 = vbslq_f16(shuffle_mask, vextq_f16(ef0145, ef0145, 2), gh0145); | ||
|
||
|
@@ -197,14 +196,17 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src, | |
vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2), | ||
temp[4 * i + 3]); | ||
} | ||
bitmask_v8 = vld1q_u16( | ||
reinterpret_cast<const uint16_t *>(neon_16bit_masks[M])); | ||
bitmask_v8 = | ||
vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[M])); | ||
for (i = 0; i < N; ++i) { | ||
if (i < 4) { | ||
temp[i] = vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i])); | ||
temp[i] = | ||
vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i])); | ||
} else { | ||
temp[i] = vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i])); | ||
temp[i] = | ||
vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i])); | ||
} | ||
vst1q_f16(&dst[i * ld_dst], vbslq_f16(bitmask_v8, temp[i], ZEROS)); | ||
vst1q_f16(&dst[i * ld_dst], | ||
vbslq_f16(bitmask_v8, temp[i], vld1q_f16(&dst[i * ld_dst]))); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.