Skip to content

Commit

Permalink
blas/gemm: Update blas/gemm fp16 and fp8 implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
GiannaP committed Jul 26, 2023
1 parent 2e2217a commit bb30b81
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 47 deletions.
66 changes: 57 additions & 9 deletions sw/blas/gemm/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

# Author: Tim Fischer <[email protected]>
# Authors: Tim Fischer <[email protected]>
# Luca Bertaccini <[email protected]>

import numpy as np
import argparse
Expand All @@ -15,20 +16,32 @@
C_TYPES = {
'64': 'double',
'32': 'float',
'16': '__fp16'
'16': '__fp16',
'8': 'char'
}

NUMPY_TYPES = {
'64': np.double,
'32': np.single,
'16': np.half
'16': np.half,
'8': np.ubyte
}

# Field layout of the 8-bit floating-point encodings, keyed by format name.
# 'mant' is used below as the left-shift that places the exponent field above
# the mantissa bits when packing sign | exponent | mantissa into one byte.
# NOTE(review): 'exp' is presumably the exponent field width in bits
# (1 sign + exp + mant = 8 for both entries) -- confirm against the FPU spec.
FP8_FORMATS = {
'fp8': {'exp': 5, 'mant': 2},
'fp8alt': {'exp': 4, 'mant': 3}
}


def format_vector_definition(id, vector, typ):
s = f'{typ} {id}[{len(vector)}] = ' + '{\n'
for i, el in enumerate(vector):
s += f'\t{el},'
if typ != 'char':
s += f'\t{el},'
else:
if type(el) == float:
print(el)
s += f'0x{el:02x},'
if i % 8 == 7:
s += '\n'
s += '};'
Expand Down Expand Up @@ -58,10 +71,40 @@ def emit_gemm_data(**kwargs):

# Generate random input matrices
dtype = NUMPY_TYPES[str(kwargs['prec'])]
a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
result = np.matmul(a, b) + kwargs['alpha'] * c
if (kwargs['prec']) == 8:
# sign bit: 0 (positive) or 1 (negative)
sign_a = np.random.randint(0, 2, (kwargs['M'], kwargs['K'])).astype(dtype)
# exponent <= 0b01111
exponent_a = np.random.randint(0, 16, (kwargs['M'], kwargs['K'])).astype(dtype)
# mantissa can be arbitrary
mantissa_a = np.random.randint(0, 4, (kwargs['M'], kwargs['K'])).astype(dtype)
# sign bit: 0 (positive) or 1 (negative)
sign_b = np.random.randint(0, 2, (kwargs['K'], kwargs['N'])).astype(dtype)
# exponent <= 0b01111
exponent_b = np.random.randint(0, 16, (kwargs['K'], kwargs['N'])).astype(dtype)
# mantissa can be arbitrary
mantissa_b = np.random.randint(0, 4, (kwargs['K'], kwargs['N'])).astype(dtype)
# sign bit: 0 (positive) or 1 (negative)
sign_c = np.random.randint(0, 2, (kwargs['M'], kwargs['N'])).astype(dtype)
# exponent <= 0b01111
exponent_c = np.random.randint(0, 16, (kwargs['M'], kwargs['N'])).astype(dtype)
# mantissa can be arbitrary
mantissa_c = np.random.randint(0, 4, (kwargs['M'], kwargs['N'])).astype(dtype)
_a = ((-1.0)**sign_a.astype(np.double))*(2.0**(exponent_a.astype(np.double)-15.0)) \
* (1.0 + mantissa_a.astype(np.double) / (2**2))
_b = ((-1.0)**sign_b.astype(np.double))*(2.0**(exponent_b.astype(np.double)-15.0)) \
* (1.0 + mantissa_b.astype(np.double) / (2**2))
_c = ((-1.0)**sign_c.astype(np.double))*(2.0**(exponent_c.astype(np.double)-15.0)) \
* (1.0 + mantissa_c.astype(np.double) / (2**2))
result = np.matmul(_a, _b) + kwargs['alpha'] * _c
a = sign_a << 7 | exponent_a << FP8_FORMATS['fp8']['mant'] | mantissa_a
b = sign_b << 7 | exponent_b << FP8_FORMATS['fp8']['mant'] | mantissa_b
c = sign_c << 7 | exponent_c << FP8_FORMATS['fp8']['mant'] | mantissa_c
else:
a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
result = np.matmul(a, b) + kwargs['alpha'] * c

# Store matrices in transposed form if requested
a = a.T if kwargs['ta'] else a
Expand All @@ -79,7 +122,12 @@ def emit_gemm_data(**kwargs):
data_str += [format_vector_definition('a', a.flatten(), C_TYPES[str(kwargs['prec'])])]
data_str += [format_vector_definition('b', b.flatten(), C_TYPES[str(kwargs['prec'])])]
data_str += [format_vector_definition('c', c.flatten(), C_TYPES[str(kwargs['prec'])])]
data_str += [format_vector_definition('result', result.flatten(), C_TYPES[str(kwargs['prec'])])]
if kwargs['prec'] == 8:
data_str += [format_vector_definition('result', result.flatten(), C_TYPES['64'])]
else:
data_str += [format_vector_definition('result',
result.flatten(),
C_TYPES[str(kwargs['prec'])])]
data_str = '\n\n'.join(data_str)

return data_str
Expand Down
175 changes: 137 additions & 38 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// SPDX-License-Identifier: Apache-2.0
//
// Author: Tim Fischer <[email protected]>
// Luca Bertaccini <[email protected]>

#include <stdint.h>

#include "snrt.h"

typedef float v2f32 __attribute__((vector_size(8)));
Expand Down Expand Up @@ -383,23 +383,41 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
"lw %[alpha], 0(%[ALPHA]) \n"
"beqz %[alpha], 1f \n"
// Load intermediate results
"flw %[c0], 0(%[C]) \n"
"flw %[c1], 4(%[C]) \n"
"flw %[c2], 8(%[C]) \n"
"flw %[c3], 12(%[C]) \n"
"flw %[c4], 16(%[C]) \n"
"flw %[c5], 20(%[C]) \n"
"flw %[c6], 24(%[C]) \n"
"flw %[c7], 28(%[C]) \n"
"flh %[reduce_reg0], 0(%[C]) \n"
"flh %[reduce_reg1], 2(%[C]) \n"
"flh %[reduce_reg2], 4(%[C]) \n"
"flh %[reduce_reg3], 6(%[C]) \n"
"flh %[reduce_reg4], 8(%[C]) \n"
"flh %[reduce_reg5], 10(%[C]) \n"
"flh %[reduce_reg6], 12(%[C]) \n"
"flh %[reduce_reg7], 14(%[C]) \n"
// Convert intermediate results before packing
"vfcvt.s.h %[reduce_reg0], %[reduce_reg0]\n"
"vfcvt.s.h %[reduce_reg1], %[reduce_reg1]\n"
"vfcvt.s.h %[reduce_reg2], %[reduce_reg2]\n"
"vfcvt.s.h %[reduce_reg3], %[reduce_reg3]\n"
"vfcvt.s.h %[reduce_reg4], %[reduce_reg4]\n"
"vfcvt.s.h %[reduce_reg5], %[reduce_reg5]\n"
"vfcvt.s.h %[reduce_reg6], %[reduce_reg6]\n"
"vfcvt.s.h %[reduce_reg7], %[reduce_reg7]\n"
// Zero the c0-c7 registers before packing
"vfcpka.s.s %[c0], %[zero], %[zero]\n"
"vfcpka.s.s %[c1], %[zero], %[zero]\n"
"vfcpka.s.s %[c2], %[zero], %[zero]\n"
"vfcpka.s.s %[c3], %[zero], %[zero]\n"
"vfcpka.s.s %[c4], %[zero], %[zero]\n"
"vfcpka.s.s %[c5], %[zero], %[zero]\n"
"vfcpka.s.s %[c6], %[zero], %[zero]\n"
"vfcpka.s.s %[c7], %[zero], %[zero]\n"
// Pack intermediate results into SIMD vector
"vfcpka.s.s %[c0], %[c0], %[zero]\n"
"vfcpka.s.s %[c1], %[c1], %[zero]\n"
"vfcpka.s.s %[c2], %[c2], %[zero]\n"
"vfcpka.s.s %[c3], %[c3], %[zero]\n"
"vfcpka.s.s %[c4], %[c4], %[zero]\n"
"vfcpka.s.s %[c5], %[c5], %[zero]\n"
"vfcpka.s.s %[c6], %[c6], %[zero]\n"
"vfcpka.s.s %[c7], %[c7], %[zero]\n"
"vfcpka.h.s %[c0], %[reduce_reg0], %[zero]\n"
"vfcpka.h.s %[c1], %[reduce_reg1], %[zero]\n"
"vfcpka.h.s %[c2], %[reduce_reg2], %[zero]\n"
"vfcpka.h.s %[c3], %[reduce_reg3], %[zero]\n"
"vfcpka.h.s %[c4], %[reduce_reg4], %[zero]\n"
"vfcpka.h.s %[c5], %[reduce_reg5], %[zero]\n"
"vfcpka.h.s %[c6], %[reduce_reg6], %[zero]\n"
"vfcpka.h.s %[c7], %[reduce_reg7], %[zero]\n"
"j 2f \n"
"1: \n"
// Initialize SIMD vector with zeros
Expand Down Expand Up @@ -548,24 +566,41 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
asm volatile(
"lw %[alpha], 0(%[ALPHA]) \n"
"beqz %[alpha], 1f \n"
// Load intermediate results
"flw %[c0], 0(%[C]) \n"
"flw %[c1], 4(%[C]) \n"
"flw %[c2], 8(%[C]) \n"
"flw %[c3], 12(%[C]) \n"
"flw %[c4], 16(%[C]) \n"
"flw %[c5], 20(%[C]) \n"
"flw %[c6], 24(%[C]) \n"
"flw %[c7], 28(%[C]) \n"
"flh %[reduce_reg0], 0(%[C]) \n"
"flh %[reduce_reg1], 2(%[C]) \n"
"flh %[reduce_reg2], 4(%[C]) \n"
"flh %[reduce_reg3], 6(%[C]) \n"
"flh %[reduce_reg4], 8(%[C]) \n"
"flh %[reduce_reg5], 10(%[C]) \n"
"flh %[reduce_reg6], 12(%[C]) \n"
"flh %[reduce_reg7], 14(%[C]) \n"
// Convert intermediate results before packing
"vfcvt.s.h %[reduce_reg0], %[reduce_reg0]\n"
"vfcvt.s.h %[reduce_reg1], %[reduce_reg1]\n"
"vfcvt.s.h %[reduce_reg2], %[reduce_reg2]\n"
"vfcvt.s.h %[reduce_reg3], %[reduce_reg3]\n"
"vfcvt.s.h %[reduce_reg4], %[reduce_reg4]\n"
"vfcvt.s.h %[reduce_reg5], %[reduce_reg5]\n"
"vfcvt.s.h %[reduce_reg6], %[reduce_reg6]\n"
"vfcvt.s.h %[reduce_reg7], %[reduce_reg7]\n"
// Zero the c0-c7 registers before packing
"vfcpka.s.s %[c0], %[zero], %[zero]\n"
"vfcpka.s.s %[c1], %[zero], %[zero]\n"
"vfcpka.s.s %[c2], %[zero], %[zero]\n"
"vfcpka.s.s %[c3], %[zero], %[zero]\n"
"vfcpka.s.s %[c4], %[zero], %[zero]\n"
"vfcpka.s.s %[c5], %[zero], %[zero]\n"
"vfcpka.s.s %[c6], %[zero], %[zero]\n"
"vfcpka.s.s %[c7], %[zero], %[zero]\n"
// Pack intermediate results into SIMD vector
"vfcpka.s.s %[c0], %[c0], %[zero]\n"
"vfcpka.s.s %[c1], %[c1], %[zero]\n"
"vfcpka.s.s %[c2], %[c2], %[zero]\n"
"vfcpka.s.s %[c3], %[c3], %[zero]\n"
"vfcpka.s.s %[c4], %[c4], %[zero]\n"
"vfcpka.s.s %[c5], %[c5], %[zero]\n"
"vfcpka.s.s %[c6], %[c6], %[zero]\n"
"vfcpka.s.s %[c7], %[c7], %[zero]\n"
"vfcpka.s.s %[c0], %[reduce_reg0], %[zero]\n"
"vfcpka.s.s %[c1], %[reduce_reg1], %[zero]\n"
"vfcpka.s.s %[c2], %[reduce_reg2], %[zero]\n"
"vfcpka.s.s %[c3], %[reduce_reg3], %[zero]\n"
"vfcpka.s.s %[c4], %[reduce_reg4], %[zero]\n"
"vfcpka.s.s %[c5], %[reduce_reg5], %[zero]\n"
"vfcpka.s.s %[c6], %[reduce_reg6], %[zero]\n"
"vfcpka.s.s %[c7], %[reduce_reg7], %[zero]\n"
"j 2f \n"
"1: \n"
// Initialize SIMD vector with zeros
Expand Down Expand Up @@ -696,6 +731,45 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
uint32_t alpha;

asm volatile(
"lw %[alpha], 0(%[ALPHA]) \n"
"beqz %[alpha], 1f \n"
"flb %[reduce_reg0], 0(%[C]) \n"
"flb %[reduce_reg1], 1(%[C]) \n"
"flb %[reduce_reg2], 2(%[C]) \n"
"flb %[reduce_reg3], 3(%[C]) \n"
"flb %[reduce_reg4], 4(%[C]) \n"
"flb %[reduce_reg5], 5(%[C]) \n"
"flb %[reduce_reg6], 6(%[C]) \n"
"flb %[reduce_reg7], 7(%[C]) \n"
// Convert intermediate results before packing
"vfcvt.s.b %[reduce_reg0], %[reduce_reg0]\n"
"vfcvt.s.b %[reduce_reg1], %[reduce_reg1]\n"
"vfcvt.s.b %[reduce_reg2], %[reduce_reg2]\n"
"vfcvt.s.b %[reduce_reg3], %[reduce_reg3]\n"
"vfcvt.s.b %[reduce_reg4], %[reduce_reg4]\n"
"vfcvt.s.b %[reduce_reg5], %[reduce_reg5]\n"
"vfcvt.s.b %[reduce_reg6], %[reduce_reg6]\n"
"vfcvt.s.b %[reduce_reg7], %[reduce_reg7]\n"
// Zero the c0-c7 registers before packing
"vfcpka.s.s %[c0], %[zero], %[zero]\n"
"vfcpka.s.s %[c1], %[zero], %[zero]\n"
"vfcpka.s.s %[c2], %[zero], %[zero]\n"
"vfcpka.s.s %[c3], %[zero], %[zero]\n"
"vfcpka.s.s %[c4], %[zero], %[zero]\n"
"vfcpka.s.s %[c5], %[zero], %[zero]\n"
"vfcpka.s.s %[c6], %[zero], %[zero]\n"
"vfcpka.s.s %[c7], %[zero], %[zero]\n"
// Pack intermediate results into SIMD vector
"vfcpka.h.s %[c0], %[reduce_reg0], %[zero]\n"
"vfcpka.h.s %[c1], %[reduce_reg1], %[zero]\n"
"vfcpka.h.s %[c2], %[reduce_reg2], %[zero]\n"
"vfcpka.h.s %[c3], %[reduce_reg3], %[zero]\n"
"vfcpka.h.s %[c4], %[reduce_reg4], %[zero]\n"
"vfcpka.h.s %[c5], %[reduce_reg5], %[zero]\n"
"vfcpka.h.s %[c6], %[reduce_reg6], %[zero]\n"
"vfcpka.h.s %[c7], %[reduce_reg7], %[zero]\n"
"j 2f \n"
"1: \n"
// Initialize SIMD vector with zeros
"vfcpka.s.s %[c0], %[zero], %[zero]\n"
"vfcpka.s.s %[c1], %[zero], %[zero]\n"
Expand All @@ -705,6 +779,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
"vfcpka.s.s %[c5], %[zero], %[zero]\n"
"vfcpka.s.s %[c6], %[zero], %[zero]\n"
"vfcpka.s.s %[c7], %[zero], %[zero]\n"
"2: \n"
// Perform expanding sum-dotproducts
"frep.o %[n_frep], 8, 0, 0 \n"
"vfdotpex.h.b %[c0], ft1, ft0 \n"
Expand Down Expand Up @@ -733,11 +808,35 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
"vfsumex.s.h %[reduce_reg5], %[c5] \n"
"vfsumex.s.h %[reduce_reg6], %[c6] \n"
"vfsumex.s.h %[reduce_reg7], %[c7] \n"
//
// Initialize reduce register to zero
"vfcpka.s.s %[c0], %[zero], %[zero] \n"
"vfcpka.s.s %[c1], %[zero], %[zero] \n"
"vfcpka.s.s %[c2], %[zero], %[zero] \n"
"vfcpka.s.s %[c3], %[zero], %[zero] \n"
"vfcpka.s.s %[c4], %[zero], %[zero] \n"
"vfcpka.s.s %[c5], %[zero], %[zero] \n"
"vfcpka.s.s %[c6], %[zero], %[zero] \n"
"vfcpka.s.s %[c7], %[zero], %[zero] \n"
// Sum-reduce vector
"vfsum.s %[c0], %[reduce_reg0] \n"
"vfsum.s %[c1], %[reduce_reg1] \n"
"vfsum.s %[c2], %[reduce_reg2] \n"
"vfsum.s %[c3], %[reduce_reg3] \n"
"vfsum.s %[c4], %[reduce_reg4] \n"
"vfsum.s %[c5], %[reduce_reg5] \n"
"vfsum.s %[c6], %[reduce_reg6] \n"
"vfsum.s %[c7], %[reduce_reg7] \n"
// Pack and convert results to FP8 vectors
"vfcpka.b.s %[c0], %[reduce_reg0], %[reduce_reg1] \n"
"vfcpkb.b.s %[c0], %[reduce_reg2], %[reduce_reg3] \n"
"vfcpkc.b.s %[c0], %[reduce_reg4], %[reduce_reg5] \n"
"vfcpkd.b.s %[c0], %[reduce_reg6], %[reduce_reg7] \n"
"vfcpka.b.s %[c0], %[c0], %[c1] \n"
"vfcpkb.b.s %[c0], %[c2], %[c3] \n"
"vfcpkc.b.s %[c0], %[c4], %[c5] \n"
"vfcpkd.b.s %[c0], %[c6], %[c7] \n"
// // // Pack and convert results to FP8 vectors
// "vfcpka.b.s %[c0], %[reduce_reg0], %[reduce_reg1] \n"
// "vfcpkb.b.s %[c0], %[reduce_reg2], %[reduce_reg3] \n"
// "vfcpkc.b.s %[c0], %[reduce_reg4], %[reduce_reg5] \n"
// "vfcpkd.b.s %[c0], %[reduce_reg6], %[reduce_reg7] \n"
: [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
[ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
[ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
Expand Down

0 comments on commit bb30b81

Please sign in to comment.