Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenBLAS] Build the BFloat16 kernels in OpenBLAS #7202

Merged
merged 3 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion O/OpenBLAS/[email protected]/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ dependencies = openblas_dependencies(platforms)

# Build the tarballs
build_tarballs(ARGS, name, version, sources, script, platforms, products, dependencies;
preferred_gcc_version=v"6", lock_microarchitecture=false, julia_compat="1.10")
preferred_gcc_version=v"6", preferred_llvm_version=v"13.0.1", lock_microarchitecture=false, julia_compat="1.10")


# Build trigger: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
From d2fc4f3b4d7f41527bc7dc8f62e9aa6229cfac89 Mon Sep 17 00:00:00 2001
From: Martin Kroeker <[email protected]>
Date: Wed, 17 Jan 2024 20:59:24 +0100
Subject: [PATCH] Increase multithreading threshold by a factor of 50

---
interface/gemv.c | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/interface/gemv.c b/interface/gemv.c
index 1f07635799..2c121f1308 100644
--- a/interface/gemv.c
+++ b/interface/gemv.c
@@ -226,7 +226,7 @@ void CNAME(enum CBLAS_ORDER order,

#ifdef SMP

- if ( 1L * m * n < 2304L * GEMM_MULTITHREAD_THRESHOLD )
+ if ( 1L * m * n < 115200L * GEMM_MULTITHREAD_THRESHOLD )
nthreads = 1;
else
nthreads = num_cpu_avail(2);

This file was deleted.

2 changes: 1 addition & 1 deletion O/OpenBLAS/[email protected]/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "OpenBLAS32"
version = v"0.3.26"

sources = openblas_sources(version)
script = openblas_script(openblas32=true)
script = openblas_script(openblas32=true, bfloat16=true)
platforms = openblas_platforms()
products = openblas_products()
dependencies = openblas_dependencies(platforms)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
From d2fc4f3b4d7f41527bc7dc8f62e9aa6229cfac89 Mon Sep 17 00:00:00 2001
From: Martin Kroeker <[email protected]>
Date: Wed, 17 Jan 2024 20:59:24 +0100
Subject: [PATCH] Increase multithreading threshold by a factor of 50

---
interface/gemv.c | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/interface/gemv.c b/interface/gemv.c
index 1f07635799..2c121f1308 100644
--- a/interface/gemv.c
+++ b/interface/gemv.c
@@ -226,7 +226,7 @@ void CNAME(enum CBLAS_ORDER order,

#ifdef SMP

- if ( 1L * m * n < 2304L * GEMM_MULTITHREAD_THRESHOLD )
+ if ( 1L * m * n < 115200L * GEMM_MULTITHREAD_THRESHOLD )
nthreads = 1;
else
nthreads = num_cpu_avail(2);
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
From 1dada6d65d89d19b2cf89b12169f6b2196c90f1d Mon Sep 17 00:00:00 2001
From: Martin Kroeker <[email protected]>
Date: Fri, 12 Jan 2024 00:10:56 +0100
Subject: [PATCH 1/2] Add compiler test and flag for AVX512BF16 capability

---
c_check | 22 ++++++++++++++++++++++
1 file changed, 22 insertions(+)

diff --git a/c_check b/c_check
index b5e4a9ad00..3e507be818 100755
--- a/c_check
+++ b/c_check
@@ -244,6 +244,7 @@ case "$data" in
esac

no_avx512=0
+no_avx512bf=0
if [ "$architecture" = "x86" ] || [ "$architecture" = "x86_64" ]; then
tmpd=$(mktemp -d 2>/dev/null || mktemp -d -t 'OBC')
tmpf="$tmpd/a.c"
@@ -262,6 +263,25 @@ if [ "$architecture" = "x86" ] || [ "$architecture" = "x86_64" ]; then
}

rm -rf "$tmpd"
+ if [ "$no_avx512" -eq 0 ]; then
+ tmpd=$(mktemp -d 2>/dev/null || mktemp -d -t 'OBC')
+ tmpf="$tmpd/a.c"
+ code='"__m512 a= _mm512_dpbf16_ps(a, (__m512bh) _mm512_loadu_si512(%1]), (__m512bh) _mm512_loadu_si512(%2]));"'
+ printf "#include <immintrin.h>\n\nint main(void){ %s; }\n" "$code" >> "$tmpf"
+ if [ "$compiler" = "PGI" ]; then
+ args=" -tp cooperlake -c -o $tmpf.o $tmpf"
+ else
+ args=" -march=cooperlake -c -o $tmpf.o $tmpf"
+ fi
+ no_avx512bf=0
+ {
+ $compiler_name $flags $args >/dev/null 2>&1
+ } || {
+ no_avx512bf=1
+ }
+
+ rm -rf "$tmpd"
+ fi
fi

no_rv64gv=0
@@ -409,6 +429,7 @@ done
[ "$makefile" = "-" ] && {
[ "$no_rv64gv" -eq 1 ] && printf "NO_RV64GV=1\n"
[ "$no_avx512" -eq 1 ] && printf "NO_AVX512=1\n"
+ [ "$no_avx512bf" -eq 1 ] && printf "NO_AVX512BF16=1\n"
[ "$no_avx2" -eq 1 ] && printf "NO_AVX2=1\n"
[ "$oldgcc" -eq 1 ] && printf "OLDGCC=1\n"
exit 0
@@ -437,6 +458,7 @@ done
[ "$no_sve" -eq 1 ] && printf "NO_SVE=1\n"
[ "$no_rv64gv" -eq 1 ] && printf "NO_RV64GV=1\n"
[ "$no_avx512" -eq 1 ] && printf "NO_AVX512=1\n"
+ [ "$no_avx512bf" -eq 1 ] && printf "NO_AVX512BF16=1\n"
[ "$no_avx2" -eq 1 ] && printf "NO_AVX2=1\n"
[ "$oldgcc" -eq 1 ] && printf "OLDGCC=1\n"
[ "$no_lsx" -eq 1 ] && printf "NO_LSX=1\n"

From 995a990e24fdcc8080128a8abc17b4ccc66bd4fd Mon Sep 17 00:00:00 2001
From: Martin Kroeker <[email protected]>
Date: Fri, 12 Jan 2024 00:12:46 +0100
Subject: [PATCH 2/2] Make AVX512 BFLOAT16 kernels conditional on compiler
capability

---
kernel/x86_64/KERNEL.COOPERLAKE | 3 ++-
kernel/x86_64/KERNEL.SAPPHIRERAPIDS | 2 ++
2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/kernel/x86_64/KERNEL.COOPERLAKE b/kernel/x86_64/KERNEL.COOPERLAKE
index dba94aea86..22b042029f 100644
--- a/kernel/x86_64/KERNEL.COOPERLAKE
+++ b/kernel/x86_64/KERNEL.COOPERLAKE
@@ -1,5 +1,5 @@
include $(KERNELDIR)/KERNEL.SKYLAKEX
-
+ifneq ($(NO_AVX512BF16), 1)
SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_cooperlake.c
SBGEMM_SMALL_K_NN = sbgemm_small_kernel_nn_cooperlake.c
SBGEMM_SMALL_K_B0_NN = sbgemm_small_kernel_nn_cooperlake.c
@@ -20,3 +20,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
+endif
diff --git a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS
index 3a832e9174..0ab2b4ddcf 100644
--- a/kernel/x86_64/KERNEL.SAPPHIRERAPIDS
+++ b/kernel/x86_64/KERNEL.SAPPHIRERAPIDS
@@ -1,5 +1,6 @@
include $(KERNELDIR)/KERNEL.COOPERLAKE

+ifneq ($(NO_AVX512BF16), 1)
SBGEMM_SMALL_M_PERMIT =
SBGEMM_SMALL_K_NN =
SBGEMM_SMALL_K_B0_NN =
@@ -20,3 +21,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
+endif
34 changes: 34 additions & 0 deletions O/OpenBLAS/[email protected]/bundled/patches/90-darwin-sve.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
From 03688a42622cf76e696859ce384e45aa26d927fc Mon Sep 17 00:00:00 2001
From: Ian McInerney <[email protected]>
Date: Tue, 23 Jan 2024 10:29:57 +0000
Subject: [PATCH] Build with proper aarch64 flags on Neoverse Darwin

We aren't affected by the problems in AppleClang that prompted this
fallback to an older architecture.
---
Makefile.arm64 | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/Makefile.arm64 b/Makefile.arm64
index ed52a9424..a8f3cb0f0 100644
--- a/Makefile.arm64
+++ b/Makefile.arm64
@@ -135,11 +135,11 @@ ifeq ($(CORE), NEOVERSEN2)
ifeq (1, $(filter 1,$(GCCVERSIONGTEQ7) $(ISCLANG)))
ifeq (1, $(filter 1,$(GCCVERSIONGTEQ10) $(ISCLANG)))
ifeq (1, $(filter 1,$(GCCMINORVERSIONGTEQ4) $(GCCVERSIONGTEQ11) $(ISCLANG)))
-ifneq ($(OSNAME), Darwin)
+#ifneq ($(OSNAME), Darwin)
CCOMMON_OPT += -march=armv8.5-a+sve+sve2+bf16 -mtune=neoverse-n2
-else
-CCOMMON_OPT += -march=armv8.2-a -mtune=cortex-a72
-endif
+#else
+#CCOMMON_OPT += -march=armv8.2-a -mtune=cortex-a72
+#endif
ifneq ($(F_COMPILER), NAG)
FCOMMON_OPT += -march=armv8.5-a+sve+sve2+bf16 -mtune=neoverse-n2
endif
--
2.43.0

4 changes: 2 additions & 2 deletions O/OpenBLAS/[email protected]/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ dependencies = openblas_dependencies(platforms)

# Build the tarballs
build_tarballs(ARGS, name, version, sources, script, platforms, products, dependencies;
preferred_gcc_version=v"6", lock_microarchitecture=false, julia_compat="1.10")
preferred_gcc_version=v"6", preferred_llvm_version=v"13.0.1", lock_microarchitecture=false, julia_compat="1.10")

# Build trigger: 3
# Build trigger: 4
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
From d2fc4f3b4d7f41527bc7dc8f62e9aa6229cfac89 Mon Sep 17 00:00:00 2001
From: Martin Kroeker <[email protected]>
Date: Wed, 17 Jan 2024 20:59:24 +0100
Subject: [PATCH] Increase multithreading threshold by a factor of 50

---
interface/gemv.c | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/interface/gemv.c b/interface/gemv.c
index 1f07635799..2c121f1308 100644
--- a/interface/gemv.c
+++ b/interface/gemv.c
@@ -226,7 +226,7 @@ void CNAME(enum CBLAS_ORDER order,

#ifdef SMP

- if ( 1L * m * n < 2304L * GEMM_MULTITHREAD_THRESHOLD )
+ if ( 1L * m * n < 115200L * GEMM_MULTITHREAD_THRESHOLD )
nthreads = 1;
else
nthreads = num_cpu_avail(2);

This file was deleted.

2 changes: 1 addition & 1 deletion O/OpenBLAS/[email protected]/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "OpenBLAS"
version = v"0.3.26"

sources = openblas_sources(version)
script = openblas_script(;aarch64_ilp64=true, num_64bit_threads=512)
script = openblas_script(;aarch64_ilp64=true, num_64bit_threads=512, bfloat16=true)
ViralBShah marked this conversation as resolved.
Show resolved Hide resolved
platforms = openblas_platforms(;experimental=true)
push!(platforms, Platform("x86_64", "linux"; sanitize="memory"))
products = openblas_products()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
From d2fc4f3b4d7f41527bc7dc8f62e9aa6229cfac89 Mon Sep 17 00:00:00 2001
From: Martin Kroeker <[email protected]>
Date: Wed, 17 Jan 2024 20:59:24 +0100
Subject: [PATCH] Increase multithreading threshold by a factor of 50

---
interface/gemv.c | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/interface/gemv.c b/interface/gemv.c
index 1f07635799..2c121f1308 100644
--- a/interface/gemv.c
+++ b/interface/gemv.c
@@ -226,7 +226,7 @@ void CNAME(enum CBLAS_ORDER order,

#ifdef SMP

- if ( 1L * m * n < 2304L * GEMM_MULTITHREAD_THRESHOLD )
+ if ( 1L * m * n < 115200L * GEMM_MULTITHREAD_THRESHOLD )
nthreads = 1;
else
nthreads = num_cpu_avail(2);
Loading