[MLP] Add MLP layer and fused linear gelu (FLG)
Viviane Potocnik committed Jul 11, 2024
1 parent f3e942b commit 9369f3d
Showing 15 changed files with 1,502 additions and 0 deletions.
25 changes: 25 additions & 0 deletions sw/dnn/fused_linear_gelu/data/params.json
@@ -0,0 +1,25 @@
// Copyright 2024 ETH Zurich and University of Bologna.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

{
setup_ssr: 1,
parallelize_m: 0,
parallelize_k: 0,
m_tiles: 2, // number of tiles in M dimension
n_tiles: 1, // number of tiles in N dimension
k_tiles: 1, // number of tiles in K dimension
load_a: 1,
load_b: 1,
load_c: 1,
transa: false,
transb: true, // must be true for SIMD
M: 16,
N: 16,
K: 16,
alpha: 1,
beta: 0,
a_gelu: -0.2888,
b_gelu: -1.769,
flg_fp: "fused_linear_gelu_fp64_naive"
}
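The flg_fp key selects the kernel by precision and implementation variant, while m_tiles/n_tiles/k_tiles control how the M x N x K problem is split into tiles that must fit in TCDM. As a rough illustration (a minimal sketch, not part of the commit; the 8-byte element size is simply what the fp64 kernel selected above implies), the per-tile footprint that the generator below validates works out as follows:

M, N, K = 16, 16, 16
m_tiles, n_tiles, k_tiles = 2, 1, 1
elem_bytes = 8  # fp64, inferred from "fused_linear_gelu_fp64_naive"

frac_m, frac_n, frac_k = M // m_tiles, N // n_tiles, K // k_tiles
a_bytes = frac_m * frac_k * elem_bytes  # 8 * 16 * 8 = 1024 B
b_bytes = frac_k * frac_n * elem_bytes  # 16 * 16 * 8 = 2048 B
c_bytes = frac_m * frac_n * elem_bytes  # 8 * 16 * 8 = 1024 B
print(a_bytes + b_bytes + c_bytes)      # 4096 B of TCDM per tile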
165 changes: 165 additions & 0 deletions sw/dnn/fused_linear_gelu/scripts/datagen.py
@@ -0,0 +1,165 @@
#!/usr/bin/env python3
# Copyright 2022 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Authors: Viviane Potocnik <[email protected]>

import numpy as np
import os
import re
import pyflexfloat as ff
import sys
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils # noqa: E402
from data_utils import DataGen, format_array_declaration, format_struct_definition, \
format_array_definition, format_ifdef_wrapper # noqa: E402


np.random.seed(42)

def sigmoid_gelu(x, x_shape, a, b):
    # Convert the flexfloat input to float32 and wrap it in a torch tensor
    # so the approximation can be evaluated with torch operations
    x = ff.np.float32(x)
    x = torch.tensor(x)
    # Polynomial approximation of the GELU gating term, parameterized by
    # the coefficients a and b (a_gelu/b_gelu in params.json)
    result = torch.sign(x) * (a * (torch.clamp(torch.abs(x), max=-b) + b)**2 + 1)
    # Return the result as a flexfloat array with the input's shape
    return ff.array(result.numpy(), x_shape)
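# Illustration (comments only, nothing here is executed by the generator):
# with the coefficients from params.json, a = -0.2888 and b = -1.769, the
# term above appears to be the i-GELU-style second-order polynomial
# approximation of erf(x); by construction it saturates to +/-1 for
# |x| >= 1.769, e.g.
#   xs = torch.linspace(-3, 3, 7)
#   torch.sign(xs) * (-0.2888 * (torch.clamp(torch.abs(xs), max=1.769) - 1.769)**2 + 1)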


class FlgDataGen(DataGen):

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits, the data should be aligned to 4KB.
BURST_ALIGNMENT = 4096

def exact_golden_model(self, alpha, a, b, beta, c, a_gelu, b_gelu):
    # Reference output: GEMM (alpha * A @ B + beta * C) followed by the
    # elementwise GELU approximation applied to every output element
    M, N, K = a.shape[0], b.shape[1], b.shape[0]
    result = beta * c
    result_shape = result.shape
    for m in range(M):
        for n in range(N):
            for k in range(K):
                result[m][n] += alpha * a[m][k] * b[k][n]
            result[m][n] = sigmoid_gelu(result[m][n], result_shape, a_gelu, b_gelu)
    return result
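# In other words, each output element of the golden model is
#   out[m][n] = P(alpha * sum_k A[m][k] * B[k][n] + beta * C[m][n])
# where P is the approximation evaluated by sigmoid_gelu above.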

def infer_implementation(self, flg_fp):
    # Kernel names follow the pattern fused_linear_gelu_fp<width>_<impl>,
    # e.g. "fused_linear_gelu_fp64_naive"
    prec, impl = re.search(r'fused_linear_gelu_fp(\d+)_(\w+)', flg_fp).group(1, 2)
    # Return the element size in bytes and the implementation variant
    return (int(prec) // 8), impl
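# For the kernel selected in params.json, "fused_linear_gelu_fp64_naive"
# resolves to 8-byte (fp64) elements and the "naive" implementation variant.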

def validate_config(self, flg_fp, parallelize_m,
parallelize_k, m_tiles, n_tiles, k_tiles, transa,
transb, M, N, K, beta, **kwargs):
frac_m = M / m_tiles
frac_n = N / n_tiles
frac_k = K / k_tiles

dtype, impl = self.infer_implementation(flg_fp)

# Calculate total TCDM occupation
# Note: doesn't account for double buffering
prec = data_utils.size_from_precision_t(dtype)
a_size = frac_m * frac_k * prec
b_size = frac_k * frac_n * prec
c_size = frac_m * frac_n * prec
total_size = a_size
total_size += b_size
total_size += c_size
data_utils.validate_tcdm_footprint(total_size)

assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously'
assert not transa, 'SIMD kernels don\'t support transposed A matrix'
assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
or transb, 'Optimized SIMD kernels only support transposed B matrix'
assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \
' if B is transposed'
assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \
' if B is transposed'
assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
'when using optimized kernels'
assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
' for FP64 (switch to NAIVE)'
assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
'Expanding GEMM kernels' \
' not supported for FP64 and FP32'
assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
' optimized implementation' \
' (switch to opt_ex)'

def emit_header(self, **kwargs):
header = [super().emit_header()]

# Validate parameters
self.validate_config(**kwargs)

M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

prec, _ = self.infer_implementation(kwargs['flg_fp'])

ff_desc = data_utils.ff_desc_from_precision_t(prec)
ctype = data_utils.ctype_from_precision_t(prec)

a = ff.array(np.random.rand(M, K), ff_desc)
b = ff.array(np.random.rand(K, N), ff_desc)
c = ff.array(np.random.rand(M, N), ff_desc)

# The GELU approximation coefficients (e.g. -0.2888 and -1.769) are
# taken from the parameter file
a_gelu = kwargs['a_gelu']
b_gelu = kwargs['b_gelu']
result = self.exact_golden_model(kwargs['alpha'], a, b, kwargs['beta'],
                                 c, a_gelu, b_gelu)

# Store matrices in transposed form if requested
a = a.T if kwargs['transa'] else a
b = b.T if kwargs['transb'] else b

a_uid = 'a'
b_uid = 'b'
c_uid = 'c'

cfg = {
'prec': prec,
**kwargs,
'a_gelu': a_gelu,
'b_gelu': b_gelu,
'a': a_uid,
'b': b_uid,
'c': c_uid,
}

a = a.flatten()
b = b.flatten()
c = c.flatten()

header += [format_array_declaration(ctype, a_uid, a.shape)]
header += [format_array_declaration(ctype, b_uid, b.shape)]
header += [format_array_declaration(ctype, c_uid, c.shape)]
header += [format_struct_definition('flg_args_t', 'flg_args', cfg)]
header += [format_array_definition(ctype, a_uid, a,
section=kwargs['section'])]
header += [format_array_definition(ctype, b_uid, b,
section=kwargs['section'])]
header += [format_array_definition(ctype, c_uid, c,
section=kwargs['section'])]
result_def = format_array_definition(ctype, 'result', result.flatten())
header += [format_ifdef_wrapper('BIST', result_def)]
header = '\n\n'.join(header)

return header


if __name__ == "__main__":
sys.exit(FlgDataGen().main())
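The generator is normally driven through DataGen.main() by the build flow, but the class can also be exercised directly. A minimal sketch under stated assumptions: a JSON5-capable parser is assumed because params.json contains comments and unquoted keys, and the 'section' key (used when emitting the array definitions) is normally injected by the surrounding tooling, so the ".data" value below is only a placeholder.

import json5  # assumption: any JSON5/HJSON-style loader works for this file

with open("sw/dnn/fused_linear_gelu/data/params.json") as f:
    params = json5.load(f)
params["section"] = ".data"  # placeholder; normally supplied by the caller

header = FlgDataGen().emit_header(**params)
print(header)  # C declarations for a, b, c, the flg_args struct and the BIST reference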
