-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'k230' of github.com:ucb-bar/Baremetal-NN into main
- Loading branch information
Showing
11 changed files
with
1,337 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
# set the RISCV option to ON | ||
option(RISCV "Build for RISC-V" ON) | ||
|
||
# CMake toolchain definition for RISC-V GCC toolchain | ||
set(CMAKE_SYSTEM_NAME "Linux" CACHE STRING "") | ||
set(CMAKE_SYSTEM_PROCESSOR "k230" CACHE STRING "") | ||
|
||
set(TOOLCHAIN_PREFIX "riscv64-unknown-linux-musl-") | ||
|
||
set(CMAKE_C_COMPILER "${TOOLCHAIN_PREFIX}gcc") | ||
set(CMAKE_ASM_COMPILER "${TOOLCHAIN_PREFIX}gcc") | ||
set(CMAKE_CXX_COMPILER "${TOOLCHAIN_PREFIX}g++") | ||
set(CMAKE_AR "${TOOLCHAIN_PREFIX}ar") | ||
set(CMAKE_LINKER "${TOOLCHAIN_PREFIX}ld") | ||
set(CMAKE_OBJCOPY "${TOOLCHAIN_PREFIX}objcopy") | ||
set(CMAKE_SIZE "${TOOLCHAIN_PREFIX}size") | ||
set(CMAKE_STRIP "${TOOLCHAIN_PREFIX}ld") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
cmake_minimum_required(VERSION 3.10) | ||
|
||
set(cpu_impl | ||
./impl/cpu/abs.c | ||
./impl/cpu/acc.c | ||
./impl/cpu/acc1.c | ||
./impl/cpu/add.c | ||
./impl/cpu/add1.c | ||
./impl/cpu/div.c | ||
./impl/cpu/dot.c | ||
./impl/cpu/fill.c | ||
./impl/cpu/max.c | ||
./impl/cpu/maximum.c | ||
./impl/cpu/maximum1.c | ||
./impl/cpu/min.c | ||
./impl/cpu/minimum.c | ||
./impl/cpu/minimum1.c | ||
./impl/cpu/mul.c | ||
./impl/cpu/mul1.c | ||
./impl/cpu/neg.c | ||
./impl/cpu/norm.c | ||
./impl/cpu/rms_norm.c | ||
./impl/cpu/sgn.c | ||
./impl/cpu/softmax.c | ||
./impl/cpu/sqr.c | ||
./impl/cpu/sqrt.c | ||
./impl/cpu/sub.c | ||
./impl/cpu/sum.c | ||
./impl/cpu/transpose.c | ||
) | ||
|
||
|
||
if (AVX) | ||
message(STATUS "Using AVX implementation") | ||
add_compile_definitions(AVX) | ||
endif () | ||
|
||
if (RVV) | ||
message(STATUS "Using RVV implementation") | ||
add_compile_definitions(RVV) | ||
|
||
if (RISCV_ZVFH) | ||
message(STATUS "Using Zvfh extension") | ||
add_compile_definitions(RISCV_ZVFH) | ||
endif () | ||
|
||
set(rvv_impl | ||
./impl/rvv/abs.c | ||
./impl/rvv/acc.c | ||
./impl/rvv/acc1.c | ||
./impl/rvv/add.c | ||
./impl/rvv/add1.c | ||
./impl/rvv/div.c | ||
./impl/rvv/dot.c | ||
./impl/rvv/max.c | ||
./impl/rvv/maximum.c | ||
./impl/rvv/maximum1.c | ||
./impl/rvv/min.c | ||
./impl/rvv/minimum.c | ||
./impl/rvv/minimum1.c | ||
./impl/rvv/mul.c | ||
./impl/rvv/mul1.c | ||
./impl/rvv/neg.c | ||
./impl/rvv/rms_norm.c | ||
./impl/rvv/sub.c | ||
./impl/rvv/transpose.c | ||
) | ||
endif () | ||
|
||
if (RVV_ASM) | ||
message(STATUS "Using RVV assembly implementation") | ||
|
||
set(rvv_impl | ||
./impl/rvv/abs.S | ||
./impl/rvv/add.S | ||
./impl/rvv/dot.S | ||
) | ||
endif () | ||
|
||
if (GEMMINI) | ||
message(STATUS "Using Gemmini implementation") | ||
add_compile_definitions(GEMMINI) | ||
|
||
set(gemmini_impl | ||
impl/gemmini/mm.c | ||
) | ||
endif () | ||
|
||
|
||
add_library(nn | ||
./functional/nn_tensor_creation.c | ||
./functional/nn_print.c | ||
./functional/nn_abs.c | ||
./functional/nn_add.c | ||
./functional/nn_batch_norm2d.c | ||
./functional/nn_conv2d.c | ||
./functional/nn_clip.c | ||
./functional/nn_copy.c | ||
./functional/nn_div.c | ||
./functional/nn_elu.c | ||
./functional/nn_fill.c | ||
./functional/nn_interpolate.c | ||
./functional/nn_layer_norm.c | ||
./functional/nn_linear.c | ||
./functional/nn_matmul.c | ||
./functional/nn_mm.c | ||
./functional/nn_norm.c | ||
./functional/nn_max.c | ||
./functional/nn_maximum.c | ||
./functional/nn_max_pool2d.c | ||
./functional/nn_min.c | ||
./functional/nn_minimum.c | ||
./functional/nn_mul.c | ||
./functional/nn_mv.c | ||
./functional/nn_neg.c | ||
./functional/nn_relu.c | ||
./functional/nn_relu6.c | ||
./functional/nn_rms_norm.c | ||
./functional/nn_softmax.c | ||
./functional/nn_silu.c | ||
./functional/nn_sub.c | ||
./functional/nn_sum.c | ||
./functional/nn_transpose.c | ||
|
||
${rvv_impl} | ||
${gemmini_impl} | ||
${cpu_impl} | ||
) | ||
|
||
target_include_directories(nn PUBLIC ./) | ||
|
||
if (X86) | ||
message(STATUS "nn: Building for x86") | ||
target_link_libraries(nn target-x86) | ||
|
||
elseif (RISCV) | ||
message(STATUS "nn: Building for RISC-V") | ||
target_link_libraries(nn target-riscv) | ||
endif () | ||
|
||
|
||
target_link_libraries(nn m) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
|
||
.globl NN__abs_f32 | ||
NN__abs_f32: | ||
beqz a0,__abs_f32_exit | ||
slli a4,a4,0x2 | ||
slli a2,a2,0x2 | ||
__abs_f32_loop: | ||
vsetvli a5,a0,e32,m1,ta,ma | ||
vlse32.v v24,(a3),a4 | ||
vfabs.v v24,v24 | ||
vsse32.v v24,(a1),a2 | ||
slli a6,a5,0x2 | ||
add a3,a3,a6 | ||
add a1,a1,a6 | ||
sub a0,a0,a5 | ||
bnez a0,__abs_f32_loop | ||
__abs_f32_exit: | ||
ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
.globl NN__add_f32 | ||
NN__add_f32: | ||
beqz a0,__add_f32_exit | ||
slli a4,a4,0x2 | ||
slli a6,a6,0x2 | ||
slli a2,a2,0x2 | ||
__add_f32_loop: | ||
vsetvli a7,a0,e32,m1,ta,ma | ||
vlse32.v v24,(a3),a4 | ||
vlse32.v v25,(a5),a6 | ||
vfadd.vv v24,v24,v25 | ||
vsse32.v v24,(a1),a2 | ||
slli t1,a7,0x2 | ||
add a3,a3,t1 | ||
add a5,a5,t1 | ||
add a1,a1,t1 | ||
sub a0,a0,a7 | ||
bnez a0,__add_f32_loop | ||
__add_f32_exit: | ||
ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
.globl NN__add1_f32 | ||
NN__add1_f32: | ||
beqz a0,__add1_f32_exit | ||
slli a4,a4,0x2 | ||
slli a2,a2,0x2 | ||
__add1_f32_loop: | ||
vsetvli a5,a0,e32,m1,ta,ma | ||
vlse32.v v24,(a3),a4 | ||
vfmv.v.f v25,fa0 | ||
vfadd.vv v24,v24,v25 | ||
vsse32.v v24,(a1),a2 | ||
slli a6,a5,0x2 | ||
add a3,a3,a6 | ||
add a1,a1,a6 | ||
sub a0,a0,a5 | ||
bnez a0,__add1_f32_loop | ||
__add1_f32_exit: | ||
ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
|
||
.globl NN__dot_f32 | ||
NN__dot_f32: | ||
vsetvli t1,zero,e32,m1,ta,ma | ||
vmv.v.i v27,0 | ||
vmv1r.v v24,v27 | ||
beqz a0,__dot_f32_exit | ||
slli a3,a3,0x2 | ||
slli a5,a5,0x2 | ||
__dot_f32_loop: | ||
vsetvli a6,a0,e32,m1,ta,ma | ||
vlse32.v v26,(a2),a3 | ||
vlse32.v v25,(a4),a5 | ||
vfmacc.vv v24,v26,v25 | ||
slli a7,a6,0x2 | ||
add a2,a2,a7 | ||
add a4,a4,a7 | ||
sub a0,a0,a6 | ||
bnez a0,__dot_f32_loop | ||
vsetvli t1,zero,e32,m1,ta,ma | ||
__dot_f32_exit: | ||
vfredusum.vs v24,v24,v27 | ||
vfmv.f.s fa5,v24 | ||
fsw fa5,0(a1) | ||
ret |
Oops, something went wrong.