[ hgemm ] Partial sum up to 2048 digits for more acceleration & trivial refactor @open sesame 05/10 10:47 #2578

Merged
merged 8 commits on May 22, 2024
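The "partial sum" in the title refers to bounding how many fp16 multiply-accumulate steps are taken before the running total is widened and flushed into fp32 storage (see the new KERNEL_4x4_ACC16/ACC8/ACC1 and SAVE_KERNEL_4X4_F16_F32 macros below). The standalone snippet below is only a conceptual sketch of why that bounding matters numerically; it is not code from this PR, and the BLOCK constant simply mirrors the 16-step ACC16 case. It assumes an AArch64 toolchain with native __fp16 support.

// Conceptual sketch (not PR code): summing many fp16 products directly loses
// precision once the running total grows, while flushing short fp16 partial
// sums into an fp32 total keeps the error bounded.
#include <cstdio>

int main() {
  const unsigned int K = 4096, BLOCK = 16;
  __fp16 naive = static_cast<__fp16>(0.F);   // one long fp16 accumulation
  __fp16 partial = static_cast<__fp16>(0.F); // bounded fp16 partial sum
  float flushed = 0.F;                       // fp32 total, like matrix C
  for (unsigned int k = 0; k < K; ++k) {
    const __fp16 prod = static_cast<__fp16>(0.25F); // stand-in for a[k] * b[k]
    naive = static_cast<__fp16>(naive + prod);
    partial = static_cast<__fp16>(partial + prod);
    if ((k + 1) % BLOCK == 0) { // flush, cf. SAVE_KERNEL_4X4_F16_F32
      flushed += static_cast<float>(partial);
      partial = static_cast<__fp16>(0.F);
    }
  }
  flushed += static_cast<float>(partial);
  // Prints roughly: naive 512, partial-sum 1024, exact 1024.
  std::printf("naive=%f partial-sum=%f exact=%f\n",
              static_cast<float>(naive), flushed, 0.25 * K);
  return 0;
}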
70 changes: 68 additions & 2 deletions nntrainer/tensor/hgemm/hgemm.cpp
@@ -64,9 +64,9 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
} else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+ } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+ } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
}
}
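The dispatch above selects a kernel by testing divisibility of N and K with bitmasks; the reordered conditions check N first but are logically unchanged. A tiny standalone check of the equivalence, for illustration only and not part of this PR:

#include <cassert>

int main() {
  for (unsigned int x = 0; x < 64; ++x) {
    assert(((x & 0x7) == 0) == (x % 8 == 0)); // mask 0x7 tests "divisible by 8"
    assert(((x & 0x3) == 0) == (x % 4 == 0)); // mask 0x3 tests "divisible by 4"
  }
  return 0;
}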
@@ -412,6 +412,72 @@ void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
free(sb);
}

void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, float *C, unsigned int ldc,
float alpha, float beta) {
__fp16 *sa = alignedMalloc(M * K);
__fp16 *sb = alignedMalloc(K * N);

unsigned int ms, mms, ns, ks;
unsigned int m_min, m2_min, n_min, k_min;
for (ms = 0; ms < M; ms += M_BLOCKING) {
m_min = M - ms;
if (m_min > M_BLOCKING) {
m_min = M_BLOCKING;
}

for (ks = 0; ks < K; ks += k_min) {
k_min = K - ks;
if (k_min >= (K_BLOCKING << 1)) {
k_min = K_BLOCKING;
} else if (k_min > K_BLOCKING) {
k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
}

n_min = N;
if (N >= N_BLOCKING * 2) {
n_min = N_BLOCKING;
} else if (N > N_BLOCKING) {
n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
}
packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);

for (mms = ms; mms < ms + m_min; mms += m2_min) {
m2_min = (ms + m_min) - mms;
if (m2_min >= 3 * GEMM_UNROLLING_4) {
m2_min = 3 * GEMM_UNROLLING_4;
} else if (m2_min >= 2 * GEMM_UNROLLING_4) {
m2_min = 2 * GEMM_UNROLLING_4;
} else if (m2_min > GEMM_UNROLLING_4) {
m2_min = GEMM_UNROLLING_4;
}

packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
sa + k_min * (mms - ms));

HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
C + mms * ldc, ldc);
}

for (ns = n_min; ns < N; ns += n_min) {
n_min = N - ns;
if (n_min >= N_BLOCKING * 2) {
n_min = N_BLOCKING;
} else if (n_min > N_BLOCKING) {
n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
}

packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
}
}
}

free(sa);
free(sb);
}
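hgemm_noTrans_4x4 above blocks M, K, and N and repacks each A and B tile before handing it to HGEMM_KERNEL_4x4. As a rough illustration of what a 4-row panel packing does, here is a hedged sketch; the function name is made up and the actual layout produced by nntrainer's packing_A4 may differ.

// Hedged sketch of 4-row panel packing: copy a rows x cols tile of A into a
// contiguous buffer, grouped so that for each k the four row elements of a
// strip sit next to each other and a 4x4 micro-kernel can stream them with a
// single vld1_f16. Illustrative only; not the exact packing_A4 layout.
static void pack_A_4row_panels(unsigned int rows, unsigned int cols,
                               const __fp16 *src, unsigned int lda,
                               __fp16 *dst) {
  for (unsigned int r = 0; r < rows; r += 4)   // one 4-row strip at a time
    for (unsigned int c = 0; c < cols; ++c)    // walk the strip along k
      for (unsigned int i = 0; i < 4; ++i)     // 4 rows become contiguous
        *dst++ = src[(r + i) * lda + c];
}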

void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
20 changes: 20 additions & 0 deletions nntrainer/tensor/hgemm/hgemm.h
@@ -181,6 +181,26 @@ void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);

/**
* @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
*
* @param M length of the row of matrix A
* @param N length of the col of matrix B
* @param K length of the col of matrix A
* @param A input matrix A
* @param lda length of the col of matrix A
* @param B input matrix B
* @param ldb length of the col of matrix B
* @param C output matrix C
* @param ldc length of the col of matrix C
* @param[in] alpha float number
* @param[in] beta float number
*/
void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, float *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);
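A minimal, hypothetical call sketch for the declaration above, assuming M, N, and K are multiples of 4 as the dispatch in hgemm.cpp requires; the include path and function name example_noTrans_4x4 are assumptions, and this is not test code from the PR.

#include <hgemm.h>
#include <vector>

void example_noTrans_4x4() {
  const unsigned int M = 8, N = 8, K = 8;
  std::vector<__fp16> A(M * K, static_cast<__fp16>(1.F));
  std::vector<__fp16> B(K * N, static_cast<__fp16>(1.F));
  std::vector<float> C(M * N, 0.F); // fp32 output buffer for this overload
  hgemm_noTrans_4x4(M, N, K, A.data(), K, B.data(), N, C.data(), N);
  // With all-ones inputs, each element of C should come out close to K.
}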

/**
* @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
*
252 changes: 248 additions & 4 deletions nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
@@ -14,6 +14,193 @@
#include <hgemm_common.h>
#include <stdlib.h>

#define INIT_KERNEL_4x4() \
v24 = vdup_n_f16(0.F); \
v25 = vdup_n_f16(0.F); \
v26 = vdup_n_f16(0.F); \
v27 = vdup_n_f16(0.F);

// 1. Partial sum 256 digits
#define KERNEL_4x4_ACC16() \
dv0 = vld1_f16(a); \
vb0 = vld1_f16(b); \
v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
dv1 = vld1_f16(a + 4); \
vb1 = vld1_f16(b + 4); \
v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
dv2 = vld1_f16(a + 4 * 2); \
vb2 = vld1_f16(b + 4 * 2); \
v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
dv3 = vld1_f16(a + 4 * 3); \
vb3 = vld1_f16(b + 4 * 3); \
v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
dv4 = vld1_f16(a + 4 * 4); \
vb4 = vld1_f16(b + 4 * 4); \
v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
dv5 = vld1_f16(a + 4 * 5); \
vb5 = vld1_f16(b + 4 * 5); \
v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
dv6 = vld1_f16(a + 4 * 6); \
vb6 = vld1_f16(b + 4 * 6); \
v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
dv7 = vld1_f16(a + 4 * 7); \
vb7 = vld1_f16(b + 4 * 7); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 8); \
vb7 = vld1_f16(b + 4 * 8); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 9); \
vb7 = vld1_f16(b + 4 * 9); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 10); \
vb7 = vld1_f16(b + 4 * 10); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 11); \
vb7 = vld1_f16(b + 4 * 11); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 12); \
vb7 = vld1_f16(b + 4 * 12); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 13); \
vb7 = vld1_f16(b + 4 * 13); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 14); \
vb7 = vld1_f16(b + 4 * 14); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
dv7 = vld1_f16(a + 4 * 15); \
vb7 = vld1_f16(b + 4 * 15); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
l += 16; \
__builtin_prefetch(b + 64, 0, 3); \
__builtin_prefetch(a + 64, 0, 3); \
b += 4 * 16; \
a += 4 * 16;

// 2. Partial sum 128 digits
#define KERNEL_4x4_ACC8() \
dv0 = vld1_f16(a); \
vb0 = vld1_f16(b); \
v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
dv1 = vld1_f16(a + 4); \
vb1 = vld1_f16(b + 4); \
v24 = vfma_lane_f16(v24, vb1, dv1, 0); \
v25 = vfma_lane_f16(v25, vb1, dv1, 1); \
v26 = vfma_lane_f16(v26, vb1, dv1, 2); \
v27 = vfma_lane_f16(v27, vb1, dv1, 3); \
dv2 = vld1_f16(a + 8); \
vb2 = vld1_f16(b + 8); \
v24 = vfma_lane_f16(v24, vb2, dv2, 0); \
v25 = vfma_lane_f16(v25, vb2, dv2, 1); \
v26 = vfma_lane_f16(v26, vb2, dv2, 2); \
v27 = vfma_lane_f16(v27, vb2, dv2, 3); \
dv3 = vld1_f16(a + 12); \
vb3 = vld1_f16(b + 12); \
v24 = vfma_lane_f16(v24, vb3, dv3, 0); \
v25 = vfma_lane_f16(v25, vb3, dv3, 1); \
v26 = vfma_lane_f16(v26, vb3, dv3, 2); \
v27 = vfma_lane_f16(v27, vb3, dv3, 3); \
dv4 = vld1_f16(a + 16); \
vb4 = vld1_f16(b + 16); \
v24 = vfma_lane_f16(v24, vb4, dv4, 0); \
v25 = vfma_lane_f16(v25, vb4, dv4, 1); \
v26 = vfma_lane_f16(v26, vb4, dv4, 2); \
v27 = vfma_lane_f16(v27, vb4, dv4, 3); \
dv5 = vld1_f16(a + 20); \
vb5 = vld1_f16(b + 20); \
v24 = vfma_lane_f16(v24, vb5, dv5, 0); \
v25 = vfma_lane_f16(v25, vb5, dv5, 1); \
v26 = vfma_lane_f16(v26, vb5, dv5, 2); \
v27 = vfma_lane_f16(v27, vb5, dv5, 3); \
dv6 = vld1_f16(a + 24); \
vb6 = vld1_f16(b + 24); \
v24 = vfma_lane_f16(v24, vb6, dv6, 0); \
v25 = vfma_lane_f16(v25, vb6, dv6, 1); \
v26 = vfma_lane_f16(v26, vb6, dv6, 2); \
v27 = vfma_lane_f16(v27, vb6, dv6, 3); \
dv7 = vld1_f16(a + 28); \
vb7 = vld1_f16(b + 28); \
v24 = vfma_lane_f16(v24, vb7, dv7, 0); \
v25 = vfma_lane_f16(v25, vb7, dv7, 1); \
v26 = vfma_lane_f16(v26, vb7, dv7, 2); \
v27 = vfma_lane_f16(v27, vb7, dv7, 3); \
l += 8; \
__builtin_prefetch(b + 32, 0, 3); \
__builtin_prefetch(a + 32, 0, 3); \
b += 4 * 8; \
a += 4 * 8;

// 3. Partial sum 16 digits
#define KERNEL_4x4_ACC1() \
dv0 = vld1_f16(a); \
vb0 = vld1_f16(b); \
v24 = vfma_lane_f16(v24, vb0, dv0, 0); \
v25 = vfma_lane_f16(v25, vb0, dv0, 1); \
v26 = vfma_lane_f16(v26, vb0, dv0, 2); \
v27 = vfma_lane_f16(v27, vb0, dv0, 3); \
l += 1; \
__builtin_prefetch(b + 4, 0, 3); \
__builtin_prefetch(a + 4, 0, 3); \
b += 4 * 1; \
a += 4 * 1;

#define SAVE_KERNEL_4X4_F16_F32() \
vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24))); \
vst1q_f32(c + ldc, vaddq_f32(vld1q_f32(c + ldc), vcvt_f32_f16(v25))); \
vst1q_f32(c + 2 * ldc, \
vaddq_f32(vld1q_f32(c + 2 * ldc), vcvt_f32_f16(v26))); \
vst1q_f32(c + 3 * ldc, vaddq_f32(vld1q_f32(c + 3 * ldc), vcvt_f32_f16(v27)));
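SAVE_KERNEL_4X4_F16_F32 widens each fp16 accumulator row with vcvt_f32_f16 and adds it into the fp32 C tile. Below is a scalar equivalent of one row's flush, for illustration only; the helper name is made up.

// Scalar sketch of one row of SAVE_KERNEL_4X4_F16_F32: widen four fp16
// partial sums to fp32 and accumulate them into the fp32 C row.
static inline void save_row_f16_to_f32(const __fp16 *v_row, float *c_row) {
  for (int i = 0; i < 4; ++i)
    c_row[i] += static_cast<float>(v_row[i]);
}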

/**
* @brief hgemm 4x4 kernel sc = sa * sb
*
@@ -37,10 +37,11 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);

- float16x4_t v24 = {0};
- float16x4_t v25 = {0};
- float16x4_t v26 = {0};
- float16x4_t v27 = {0};
+ float16x4_t v24;
+ float16x4_t v25;
+ float16x4_t v26;
+ float16x4_t v27;
+ INIT_KERNEL_4x4();

for (l = 0; l < K; l += VL_FP16_HALF) {
float16x4_t v0 = vld1_f16(b);
@@ -101,3 +101,59 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
b = sb;
}
}

/**
* @brief hgemm 4x4 kernel sc = sa * sb, with fp16 inputs accumulated into an fp32 sc
*
* @param M length of the row of matrix A
* @param N length of the col of matrix B
* @param K length of the col of matrix A
* @param sa sub-matrix of input matrix A
* @param sb sub-matrix of input matrix B
* @param sc sub-matrix of output matrix C
* @param ldc leading dimension of matrix C
*/
void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
__fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
assert(M > 0 && N > 0 && K > 0);
assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);

__fp16 *a = sa, *b = sb;
float *c = sc;
unsigned int i, j, l;
unsigned int K16 = (K >> 4) << 4;
unsigned int K8 = (K >> 3) << 3;
for (i = 0; i < M; i += VL_FP16_HALF) {
for (j = 0; j < N; j += VL_FP16_HALF) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);

float16x4_t v24, v25, v26, v27;
float16x4_t dv0, dv1, dv2, dv3, dv4, dv5, dv6, dv7;
float16x4_t vb0, vb1, vb2, vb3, vb4, vb5, vb6, vb7;
l = 0;
for (; l < K16;) {
INIT_KERNEL_4x4();
KERNEL_4x4_ACC16();
SAVE_KERNEL_4X4_F16_F32();
}
for (; l < K8;) {
INIT_KERNEL_4x4();
KERNEL_4x4_ACC8();
SAVE_KERNEL_4X4_F16_F32();
}
for (; l < K;) {
INIT_KERNEL_4x4();
KERNEL_4x4_ACC1();
SAVE_KERNEL_4X4_F16_F32();
}

c += 4;
a -= 4 * K;
}
sc += ldc * 4;
c = sc;
a += 4 * K;
b = sb;
}
}
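The new fp32-output kernel above walks K in three phases: 16-step partial sums while l < K16, at most one 8-step partial sum while l < K8, then single steps, flushing each partial sum into C via SAVE_KERNEL_4X4_F16_F32. Below is a scalar dot-product sketch of the same split; the names are illustrative and it is not code from this PR.

// Scalar sketch of the K16/K8/1 loop split: bounded fp16 partial sums,
// each flushed into an fp32 total, mirroring ACC16/ACC8/ACC1 + SAVE above.
float dot_split_16_8_1(const __fp16 *a, const __fp16 *b, unsigned int K) {
  const unsigned int K16 = (K >> 4) << 4; // largest multiple of 16 <= K
  const unsigned int K8 = (K >> 3) << 3;  // largest multiple of 8 <= K
  float acc = 0.F;
  unsigned int l = 0;
  auto run = [&](unsigned int steps) {    // one bounded fp16 partial sum
    __fp16 partial = static_cast<__fp16>(0.F);
    for (unsigned int s = 0; s < steps; ++s, ++l)
      partial = static_cast<__fp16>(partial + a[l] * b[l]);
    acc += static_cast<float>(partial);   // flush into the fp32 total
  };
  while (l < K16) run(16);
  while (l < K8)  run(8);
  while (l < K)   run(1);
  return acc;
}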