forked from pulp-platform/snitch_cluster
Commit 9369f3d
[MLP] Add MLP layer and fused linear gelu (FLG)
Viviane Potocnik committed Jul 11, 2024
1 parent: f3e942b
Showing 15 changed files with 1,502 additions and 0 deletions.
@@ -0,0 +1,25 @@
// Copyright 2024 ETH Zurich and University of Bologna.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

{
    setup_ssr: 1,
    parallelize_m: 0,
    parallelize_k: 0,
    m_tiles: 2, // number of tiles in M dimension
    n_tiles: 1, // number of tiles in N dimension
    k_tiles: 1, // number of tiles in K dimension
    load_a: 1,
    load_b: 1,
    load_c: 1,
    transa: false,
    transb: true, // must be true for SIMD
    M: 16,
    N: 16,
    K: 16,
    alpha: 1,
    beta: 0,
    a_gelu: -0.2888,
    b_gelu: -1.769,
    flg_fp: "fused_linear_gelu_fp64_naive"
}
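The a_gelu = -0.2888 and b_gelu = -1.769 values parameterize the second-order, sign-symmetric polynomial that the data generator below uses as its GELU surrogate; the same pair of constants appears in the i-GELU approximation from I-BERT. The following sketch is an illustration added here (not part of the commit) that checks, under that reading, how closely the polynomial tracks PyTorch's exact GELU when plugged into GELU(x) = x/2 * (1 + erf(x/sqrt(2))):

import torch

A_GELU, B_GELU = -0.2888, -1.769  # constants from the config above

def poly_erf(x, a=A_GELU, b=B_GELU):
    # Second-order, sign-symmetric polynomial: sgn(x) * (a * (min(|x|, -b) + b)^2 + 1)
    return torch.sign(x) * (a * (torch.clamp(torch.abs(x), max=-b) + b) ** 2 + 1)

x = torch.linspace(-4.0, 4.0, steps=101)
approx = x * 0.5 * (1 + poly_erf(x / 2 ** 0.5))  # GELU via the polynomial erf surrogate
exact = torch.nn.functional.gelu(x)
print("max |error| =", (approx - exact).abs().max().item())  # on the order of 1e-2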
@@ -0,0 +1,165 @@
#!/usr/bin/env python3
# Copyright 2022 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Authors: Viviane Potocnik <[email protected]>

import numpy as np
import os
import re
import pyflexfloat as ff
import sys
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils  # noqa: E402
from data_utils import DataGen, format_array_declaration, format_struct_definition, \
    format_array_definition, format_ifdef_wrapper  # noqa: E402


np.random.seed(42)

def sigmoid_gelu(x, x_shape, a, b):
    # Evaluate the sign-symmetric second-order polynomial
    #   sgn(x) * (a * (min(|x|, -b) + b)^2 + 1)
    # using float32/torch for the intermediate math.
    x = ff.np.float32(x)
    x = torch.tensor(x)
    result = torch.sign(x) * (a * (torch.clamp(torch.abs(x), max=-b) + b)**2 + 1)
    # Return as the same type as the input
    return ff.array(result.numpy(), x_shape)


class FlgDataGen(DataGen):

    # AXI splits bursts crossing 4KB address boundaries. To minimize
    # the occurrence of these splits the data should be aligned to 4KB
    BURST_ALIGNMENT = 4096

    def exact_golden_model(self, alpha, a, b, beta, c, a_gelu, b_gelu):
        M, N, K = a.shape[0], b.shape[1], b.shape[0]
        result = beta * c
        result_shape = result.shape
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    result[m][n] += a[m][k] * b[k][n]
                # Apply the GELU approximation once per output element,
                # after the full K accumulation
                result[m][n] = sigmoid_gelu(result[m][n], result_shape, a_gelu, b_gelu)
        return result

    def infer_implementation(self, flg_fp):
        # flg_fp has the form "fused_linear_gelu_fp<width>_<impl>",
        # e.g. "fused_linear_gelu_fp64_opt"
        prec, impl = re.search(r'fused_linear_gelu_fp(\d+)_(\w+)', flg_fp).group(1, 2)
        return (int(prec) / 8), impl

    def validate_config(self, flg_fp, parallelize_m,
                        parallelize_k, m_tiles, n_tiles, k_tiles, transa,
                        transb, M, N, K, beta, **kwargs):
        frac_m = M / m_tiles
        frac_n = N / n_tiles
        frac_k = K / k_tiles

        dtype, impl = self.infer_implementation(flg_fp)

        # Calculate total TCDM occupation
        # Note: doesn't account for double buffering
        prec = data_utils.size_from_precision_t(dtype)
        a_size = frac_m * frac_k * prec
        b_size = frac_k * frac_n * prec
        c_size = frac_m * frac_n * prec
        total_size = a_size
        total_size += b_size
        total_size += c_size
        data_utils.validate_tcdm_footprint(total_size)

        assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
        assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
        assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
        assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously'
        assert not transa, 'SIMD kernels don\'t support transposed A matrix'
        assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
            or transb, 'Optimized SIMD kernels only support transposed B matrix'
        assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \
            ' if B is transposed'
        assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \
            ' if B is transposed'
        assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
            'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
            'when using optimized kernels'
        assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
        assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
            ' for FP64 (switch to NAIVE)'
        assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
            'Expanding GEMM kernels not supported for FP64 and FP32'
        assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
            ' optimized implementation (switch to opt_ex)'

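    # Worked example (added for illustration, not generated output): with the
    # configuration in this commit (M = N = K = 16, m_tiles = 2,
    # n_tiles = k_tiles = 1, FP64 so prec = 8 bytes), a single tile occupies
    #   a: (16/2) * 16 * 8 B = 1024 B
    #   b:  16    * 16 * 8 B = 2048 B
    #   c: (16/2) * 16 * 8 B = 1024 B
    # i.e. about 4 KiB of TCDM, well within the footprint checked above
    # (double buffering, where used, would roughly double this).
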
    def emit_header(self, **kwargs):
        header = [super().emit_header()]

        # Validate parameters
        self.validate_config(**kwargs)

        M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

        prec, _ = self.infer_implementation(kwargs['flg_fp'])

        ff_desc = data_utils.ff_desc_from_precision_t(prec)
        ctype = data_utils.ctype_from_precision_t(prec)

        a = ff.array(np.random.rand(M, K), ff_desc)
        b = ff.array(np.random.rand(K, N), ff_desc)
        c = ff.array(np.random.rand(M, N), ff_desc)

        # GELU approximation coefficients (e.g. -0.2888 and -1.769 in the config)
        a_gelu = kwargs['a_gelu']
        b_gelu = kwargs['b_gelu']
        result = self.exact_golden_model(1, a, b, kwargs['beta'], c, a_gelu, b_gelu)

        # Store matrices in transposed form if requested
        a = a.T if kwargs['transa'] else a
        b = b.T if kwargs['transb'] else b

        a_uid = 'a'
        b_uid = 'b'
        c_uid = 'c'

        cfg = {
            'prec': prec,
            **kwargs,
            'a_gelu': a_gelu,
            'b_gelu': b_gelu,
            'a': a_uid,
            'b': b_uid,
            'c': c_uid,
        }

        a = a.flatten()
        b = b.flatten()
        c = c.flatten()

        header += [format_array_declaration(ctype, a_uid, a.shape)]
        header += [format_array_declaration(ctype, b_uid, b.shape)]
        header += [format_array_declaration(ctype, c_uid, c.shape)]
        header += [format_struct_definition('flg_args_t', 'flg_args', cfg)]
        header += [format_array_definition(ctype, a_uid, a,
                                           section=kwargs['section'])]
        header += [format_array_definition(ctype, b_uid, b,
                                           section=kwargs['section'])]
        header += [format_array_definition(ctype, c_uid, c,
                                           section=kwargs['section'])]
        result_def = format_array_definition(ctype, 'result', result.flatten())
        header += [format_ifdef_wrapper('BIST', result_def)]
        header = '\n\n'.join(header)

        return header


if __name__ == "__main__":
    sys.exit(FlgDataGen().main())
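For reference, and assuming the reading of exact_golden_model above (activation applied once per output element after the K accumulation), the golden model reduces to the following plain NumPy/PyTorch sketch. The helper name flg_reference is illustrative and not part of this commit:

import numpy as np
import torch

def flg_reference(a, b, c, beta=0.0, a_gelu=-0.2888, b_gelu=-1.769):
    # Dense float64 reference: y = beta * c + a @ b, followed element-wise by the
    # polynomial GELU surrogate used in datagen.py (flexfloat rounding omitted).
    y = torch.tensor(beta * c + a @ b)
    act = torch.sign(y) * (a_gelu * (torch.clamp(torch.abs(y), max=-b_gelu) + b_gelu) ** 2 + 1)
    return act.numpy()

M = N = K = 16
rng = np.random.default_rng(42)
out = flg_reference(rng.random((M, K)), rng.random((K, N)), rng.random((M, N)))
print(out.shape)  # (16, 16)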