Skip to content

Commit

Permalink
#EDITS: some mass updated to linear algebra module and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
akielaries committed Feb 27, 2024
1 parent d0465af commit 24091cb
Showing 1 changed file with 3 additions and 104 deletions.
107 changes: 3 additions & 104 deletions modules/linalg/dgemm_arr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include <stdlib.h>
#include <time.h>

#if defined(__SSE__)
#if defined(__SSE2__)

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -168,33 +168,7 @@ void gpmp::linalg::DGEMM::pack_buffer_B(int kc,
}
}

// micro kernel that multiplies panels from A and B using assembly kernels
void gpmp::linalg::DGEMM::dgemm_micro_kernel(long kc,
double alpha,
const double *A,
const double *B,
double beta,
double *C,
long incRowC,
long incColC,
const double *nextA,
const double *nextB) {
long kb = kc / 4;
long kl = kc % 4;

dgemm_kernel_asm(A,
B,
C,
nextA,
nextB,
kl,
kb,
incRowC,
incColC,
alpha,
beta);
}

// micro kernel that multiplies panels from A and B
void gpmp::linalg::DGEMM::dgemm_micro_kernel(int kc,
double alpha,
const double *A,
Expand Down Expand Up @@ -333,77 +307,6 @@ void gpmp::linalg::DGEMM::dgemm_macro_kernel(int mc,
int mr, nr;
int i, j;

// use assembly kernel function
#if defined(__SSE__)
const double *nextA;
const double *nextB;

for (j = 0; j < np; ++j) {
nr = (j != np - 1 || _nr == 0) ? BLOCK_SZ_NR : _nr;
nextB = &DGEMM_BUFF_B[j * kc * BLOCK_SZ_NR];

for (i = 0; i < mp; ++i) {
mr = (i != mp - 1 || _mr == 0) ? BLOCK_SZ_MR : _mr;
nextA = &DGEMM_BUFF_A[(i + 1) * kc * BLOCK_SZ_MR];

if (i == mp - 1) {
nextA = DGEMM_BUFF_A;
nextB = &DGEMM_BUFF_B[(j + 1) * kc * BLOCK_SZ_NR];
if (j == np - 1) {
nextB = DGEMM_BUFF_B;
}
}

if (mr == BLOCK_SZ_MR && nr == BLOCK_SZ_NR) {

dgemm_micro_kernel(
kc,
alpha,
&DGEMM_BUFF_A[i * kc * BLOCK_SZ_MR],
&DGEMM_BUFF_B[j * kc * BLOCK_SZ_NR],
beta,
&C[i * BLOCK_SZ_MR * incRowC + j * BLOCK_SZ_NR * incColC],
incRowC,
incColC,
nextA,
nextB);
}

else {
dgemm_micro_kernel(kc,
alpha,
&DGEMM_BUFF_A[i * kc * BLOCK_SZ_MR],
&DGEMM_BUFF_B[j * kc * BLOCK_SZ_NR],
0.0,
DGEMM_BUFF_C,
1,
BLOCK_SZ_MR,
nextA,
nextB);
dgescal(
mr,
nr,
beta,
&C[i * BLOCK_SZ_MR * incRowC + j * BLOCK_SZ_NR * incColC],
incRowC,
incColC);
dgeaxpy(mr,
nr,
1.0,
DGEMM_BUFF_C,
1,
BLOCK_SZ_MR,
&DGEMM_BUFF_C[i * BLOCK_SZ_MR * incRowC +
j * BLOCK_SZ_NR * incColC],
incRowC,
incColC);
}
}
}

// default implementation
#else

for (j = 0; j < np; ++j) {
nr = (j != np - 1 || _nr == 0) ? BLOCK_SZ_NR : _nr;

Expand All @@ -420,9 +323,7 @@ void gpmp::linalg::DGEMM::dgemm_macro_kernel(int mc,
&C[i * BLOCK_SZ_MR * incRowC + j * BLOCK_SZ_NR * incColC],
incRowC,
incColC);
}

else {
} else {
dgemm_micro_kernel(kc,
alpha,
&DGEMM_BUFF_A[i * kc * BLOCK_SZ_MR],
Expand Down Expand Up @@ -451,8 +352,6 @@ void gpmp::linalg::DGEMM::dgemm_macro_kernel(int mc,
}
}
}

#endif
}

// Main DGEMM entrypoint, compute C <- beta*C + alpha*A*B
Expand Down

0 comments on commit 24091cb

Please sign in to comment.