[ BLAS ] Support non-4-divisible case in matrix transpose
- Previously, transposing a matrix whose column length is not divisible by 4 produced incorrect results due to a code defect.
- Fix the bug and refactor the calling interface: the transpose fallback is relocated so that, when NEON is supported, transpose_neon is called directly (see the fallback sketch below).
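
For reference, below is a minimal self-contained sketch (not part of this commit) of the plain fallback transpose that blas_interface.cpp now provides, exercised on a non-4-divisible shape. The float instantiation, the 5x7 dimensions, and the main() check are illustrative assumptions, not nntrainer code.

```cpp
#include <cassert>
#include <vector>

// Mirrors the scalar fallback: dst is the N x M transpose of the M x N src,
// i.e. dst(j, i) = src(i, j), with explicit leading dimensions.
template <typename T>
static void transpose_fallback(unsigned int M, unsigned int N, const T *src,
                               unsigned int ld_src, T *dst,
                               unsigned int ld_dst) {
  for (unsigned int j = 0; j < N; j++) {
    for (unsigned int i = 0; i < M; i++) {
      dst[i + j * ld_dst] = src[i * ld_src + j];
    }
  }
}

int main() {
  const unsigned int M = 5, N = 7; // neither dimension is divisible by 4
  std::vector<float> src(M * N), dst(N * M);
  for (unsigned int k = 0; k < M * N; ++k)
    src[k] = static_cast<float>(k);

  transpose_fallback<float>(M, N, src.data(), /*ld_src=*/N, dst.data(),
                            /*ld_dst=*/M);

  // Every (i, j) element of src must land at (j, i) in dst.
  for (unsigned int i = 0; i < M; ++i)
    for (unsigned int j = 0; j < N; ++j)
      assert(dst[i + j * M] == src[i * N + j]);
  return 0;
}
```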

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
skykongkong8 committed May 23, 2024
1 parent c3b6175 commit 4efa98b
Showing 5 changed files with 120 additions and 184 deletions.
31 changes: 17 additions & 14 deletions nntrainer/tensor/blas_interface.cpp
@@ -78,6 +78,21 @@

namespace nntrainer {

template <typename T>
static inline void transpose_fallback(
unsigned int M,
unsigned int N,
const T* src,
unsigned int ld_src,
T* dst,
unsigned int ld_dst) {
for (unsigned int j = 0; j < N; j++) {
for (unsigned int i = 0; i < M; i++) {
dst[i + j * ld_dst] = src[i * ld_src + j];
}
}
}

#ifdef ENABLE_FP16
static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
const int incX, _FP16 *Y, const int incY) {
@@ -535,21 +550,9 @@ void transpose_matrix(const unsigned int M, const unsigned int N,
const _FP16 *src, unsigned int ld_src, _FP16 *dst,
unsigned int ld_dst) {
#ifdef USE_NEON
/// @note Final form of transpose_neon is NOT having fallback. Debugging WIP.
if ((M & 0x3) == 0) {
transpose_neon<_FP16>(M, N, src, ld_src, dst, ld_dst);
} else {
transpose_fallback<_FP16>(M, N, src, ld_src, dst, ld_dst);
}
transpose_neon<_FP16>(M, N, src, ld_src, dst, ld_dst);
#else
/// @note This code should be replaced with:
/// transpose_fallback<_FP16>(M, N, src, ld_src, dst, ld_dst);
/// during arch-dep freeing refactorization.
for (unsigned int j = 0; j < N; j++) {
for (unsigned int i = 0; i < M; i++) {
dst[i + j * ld_dst] = src[i * ld_src + j];
}
}
transpose_fallback<_FP16>(M, N, src, ld_src, dst, ld_dst);
#endif
}
#endif
@@ -2,7 +2,7 @@
/**
* Copyright (C) 2024 Sungsik Kong <[email protected]>
*
* @file transpose_utils_neon.h
* @file matrix_transpose_kernels_neon.h
* @date 09 May 2024
* @see https://github.com/nnstreamer/nntrainer
* @author Sungsik Kong <[email protected]>
@@ -14,7 +14,7 @@
#include <arm_neon.h>
#include <cassert>
#include <cstdint>
#include "./mask_neon.h"
#include <mask_neon.h>

#define TRANSPOSE_FP16_4x4(row0, row1, row2, row3) \
float16x4x2_t row01 = vtrn_f16(row0, row1); \
@@ -30,7 +30,6 @@
vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(row01.val[1])), \
vget_high_f32(vcvt_f32_f16(row23.val[1]))));


static inline void transpose_kernel_4x4_neon(const __fp16 *src,
unsigned int ld_src, __fp16 *dst,
unsigned int ld_dst) {
@@ -52,8 +51,8 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
unsigned int ld_src, __fp16 *dst,
unsigned int ld_dst) {


uint16x4_t bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[N]));
uint16x4_t bitmask_v8 =
vld1_u16(reinterpret_cast<const uint16_t *>(masks[N]));
float16x4_t input[4];
float16x4_t ZEROS = vmov_n_f16(0.F);

@@ -81,13 +80,13 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
vcvt_f16_f32(vcombine_f32(vget_low_f32(vcvt_f32_f16(temp[i / 2])),
vget_low_f32(vcvt_f32_f16(temp[2 + i / 2]))));
} else {
input[i] =
vcvt_f16_f32(vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
input[i] = vcvt_f16_f32(
vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
}
vst1_f16(&dst[i * ld_dst], vbsl_f16(bitmask_v8, input[i], ZEROS));
vst1_f16(&dst[i * ld_dst],
vbsl_f16(bitmask_v8, input[i], vld1_f16(&dst[i * ld_dst])));
}

}

static inline void transpose_kernel_8x8_neon(const __fp16 *src,
@@ -126,7 +125,7 @@ static inline void transpose_kernel_8x8_neon(const __fp16 *src,
vld1q_u16(reinterpret_cast<const uint16_t *>(shuffle_masks));
abcd04 = vbslq_f16(shuffle_mask, ab0145, vextq_f16(cd0145, cd0145, 6));
abcd15 = vbslq_f16(shuffle_mask, vextq_f16(ab0145, ab0145, 2), cd0145);

efgh04 = vbslq_f16(shuffle_mask, ef0145, vextq_f16(gh0145, gh0145, 6));
efgh15 = vbslq_f16(shuffle_mask, vextq_f16(ef0145, ef0145, 2), gh0145);

@@ -197,14 +196,17 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2),
temp[4 * i + 3]);
}
bitmask_v8 = vld1q_u16(
reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
bitmask_v8 =
vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
for (i = 0; i < N; ++i) {
if (i < 4) {
temp[i] = vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i]));
temp[i] =
vcombine_f16(vget_low_f16(input[i]), vget_low_f16(input[4 + i]));
} else {
temp[i] = vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i]));
temp[i] =
vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i]));
}
vst1q_f16(&dst[i * ld_dst], vbslq_f16(bitmask_v8, temp[i], ZEROS));
vst1q_f16(&dst[i * ld_dst],
vbslq_f16(bitmask_v8, temp[i], vld1q_f16(&dst[i * ld_dst])));
}
}
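
A note on the functional change above (editorial, not part of the diff): the bit-select store now blends the transposed lanes with the existing contents of dst (loaded via vld1_f16/vld1q_f16) rather than with zeros, so a partial tail store no longer clobbers columns beyond the valid range. Below is a scalar sketch of that pattern, with a hypothetical helper name and a fixed lane width of 4.

```cpp
// Scalar analogue of vst1_f16(dst, vbsl_f16(mask, value, vld1_f16(dst))):
// lanes selected by the mask take the freshly transposed value, the other
// lanes keep whatever dst already holds instead of being zeroed.
void masked_store_tail(const bool mask[4], const float value[4], float dst[4]) {
  for (int lane = 0; lane < 4; ++lane) {
    dst[lane] = mask[lane] ? value[lane] : dst[lane];
  }
}
```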
214 changes: 81 additions & 133 deletions nntrainer/tensor/matrix_transpose_neon/matrix_transpose_neon.cpp
@@ -12,164 +12,112 @@
*/

#include <arm_neon.h>
#include "./transpose_utils_neon.h"
#include "./matrix_transpose_neon.h"
#include <matrix_transpose_kernels_neon.h>
#include <matrix_transpose_neon.h>

template <>
void transpose_fallback(
unsigned int M,
unsigned int N,
const __fp16* src,
unsigned int ld_src,
__fp16* dst,
unsigned int ld_dst) {
for (unsigned int j = 0; j < N; j++) {
for (unsigned int i = 0; i < M; i++) {
dst[i + j * ld_dst] = src[i * ld_src + j];
}
}
}

template <>
void transpose_neon(
unsigned int M,
unsigned int N,
const __fp16* src,
unsigned int ld_src,
__fp16* dst,
unsigned int ld_dst) {
void transpose_neon(unsigned int M, unsigned int N, const __fp16 *src,
unsigned int ld_src, __fp16 *dst, unsigned int ld_dst) {
unsigned int ib = 0, jb = 0;
if (N % 8 > 0 && N % 8 < 4) {
for (ib = 0; ib + 8 <= M; ib += 8) {
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_neon(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
transpose_kernel_8x8_neon(&src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
for (unsigned int i = ib; i < ib + 8; i += 4) {
transpose_kernel_mxn_neon_128<4>(
N - jb,
&src[i * ld_src + jb],
ld_src,
&dst[i + jb * ld_dst],
ld_dst);
transpose_kernel_mxn_neon_128<4>(N - jb, &src[i * ld_src + jb], ld_src,
&dst[i + jb * ld_dst], ld_dst);
}
}
} else if (N % 8 == 4) {
for (ib = 0; ib + 8 <= M; ib += 8) {
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_neon(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
transpose_kernel_8x8_neon(&src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
for (unsigned int i = ib; i < ib + 8; i += 4) {
transpose_kernel_4x4_neon(
&src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
transpose_kernel_4x4_neon(&src[i * ld_src + jb], ld_src,
&dst[i + jb * ld_dst], ld_dst);
}
}
} else {
for (ib = 0; ib + 8 <= M; ib += 8) {
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_neon(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
transpose_kernel_8x8_neon(&src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<8>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
transpose_kernel_mxn_neon_256<8>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
}
}
switch (M - ib) {
case 1:
for (unsigned int j = 0; j < N; ++j) {
dst[ib + j * ld_dst] = src[ib * ld_src + j];
}
break;
case 2:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_neon_128<2>(
4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_128<2>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 3:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_neon_128<3>(
4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_128<3>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 4:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_4x4_neon(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_128<4>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 5:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_neon_256<5>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<5>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 6:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_neon_256<6>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<6>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 7:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_neon_256<7>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<7>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 1:
for (unsigned int j = 0; j < N; ++j) {
dst[ib + j * ld_dst] = src[ib * ld_src + j];
}
break;
case 2:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_neon_128<2>(4, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_128<2>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
break;
case 3:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_neon_128<3>(4, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_128<3>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
break;
case 4:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_4x4_neon(&src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_128<4>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
break;
case 5:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_neon_256<5>(8, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<5>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
break;
case 6:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_neon_256<6>(8, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<6>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
break;
case 7:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_neon_256<7>(8, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_neon_256<7>(N - jb, &src[ib * ld_src + jb], ld_src,
&dst[ib + jb * ld_dst], ld_dst);
}
break;
}
}
