-
Notifications
You must be signed in to change notification settings - Fork 528
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add NEON and SVE implementations for Float16 conversions
- Loading branch information
Showing
4 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
|
@@ -135,6 +136,26 @@ FBGEMM_API void FloatToFloat16_avx512( | |
size_t size, | ||
bool do_clip = false); | ||
|
||
/** | ||
* @brief SVE implementation to convert fp32 numbers to fp16 numbers. | ||
* | ||
* @param src Pointer to the fp32 input elements. | ||
* @param dst Pointer to the fp16 output elements. | ||
* @param size Number of elements to convert. | ||
* @param do_clip When true, each input is clamped to [-65504, 65504] before | ||
* it is converted. | ||
*/ | ||
FBGEMM_API void FloatToFloat16_sve( | ||
const float* src, | ||
float16* dst, | ||
size_t size, | ||
bool do_clip = false); | ||
|
||
/** | ||
* @brief NEON implementation to convert fp32 numbers to fp16 numbers. | ||
* | ||
* @param src Pointer to the fp32 input elements. | ||
* @param dst Pointer to the fp16 output elements. | ||
* @param size Number of elements to convert. | ||
* @param do_clip When true, each input is clamped to [-65504, 65504] before | ||
* it is converted. | ||
*/ | ||
FBGEMM_API void FloatToFloat16_neon( | ||
const float* src, | ||
float16* dst, | ||
size_t size, | ||
bool do_clip = false); | ||
|
||
/** | ||
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers. | ||
* | ||
|
@@ -149,6 +170,20 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size); | |
FBGEMM_API void | ||
Float16ToFloat_avx512(const float16* src, float* dst, size_t size); | ||
|
||
/** | ||
* @brief SVE implementation to convert fp16 numbers to fp32 numbers. | ||
* | ||
* @param src Pointer to the fp16 input elements. | ||
* @param dst Pointer to the fp32 output elements. | ||
* @param size Number of elements to convert. | ||
*/ | ||
FBGEMM_API void | ||
Float16ToFloat_sve(const float16* src, float* dst, size_t size); | ||
|
||
/** | ||
* @brief NEON implementation to convert fp16 numbers to fp32 numbers. | ||
* | ||
* @param src Pointer to the fp16 input elements. | ||
* @param dst Pointer to the fp32 output elements. | ||
* @param size Number of elements to convert. | ||
*/ | ||
FBGEMM_API void | ||
Float16ToFloat_neon(const float16* src, float* dst, size_t size); | ||
|
||
/** | ||
* @brief Transform all entries in a matrix from fp32 to float16 and back to | ||
* fp32. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
|
@@ -39,10 +40,17 @@ void FloatToFloat16_simd( | |
bool do_clip) { | ||
// Run time CPU detection | ||
if (cpuinfo_initialize()) { | ||
#ifdef __aarch64__ | ||
if (fbgemmHasArmSveSupport()) { | ||
FloatToFloat16_sve(src, dst, size, do_clip); | ||
} else if (fbgemmHasArmNeonSupport()) { | ||
FloatToFloat16_neon(src, dst, size, do_clip); | ||
#else | ||
if (fbgemmHasAvx512Support()) { | ||
FloatToFloat16_avx512(src, dst, size, do_clip); | ||
} else if (fbgemmHasAvx2Support()) { | ||
FloatToFloat16_avx2(src, dst, size, do_clip); | ||
#endif | ||
} else { | ||
FloatToFloat16_ref(src, dst, size, do_clip); | ||
return; | ||
|
@@ -55,10 +63,17 @@ void FloatToFloat16_simd( | |
void Float16ToFloat_simd(const float16* src, float* dst, size_t size) { | ||
// Run time CPU detection | ||
if (cpuinfo_initialize()) { | ||
#ifdef __aarch64__ | ||
if (fbgemmHasArmSveSupport()) { | ||
Float16ToFloat_sve(src, dst, size); | ||
} else if (fbgemmHasArmNeonSupport()) { | ||
Float16ToFloat_neon(src, dst, size); | ||
#else | ||
if (fbgemmHasAvx512Support()) { | ||
Float16ToFloat_avx512(src, dst, size); | ||
} else if (fbgemmHasAvx2Support()) { | ||
Float16ToFloat_avx2(src, dst, size); | ||
#endif | ||
} else { | ||
Float16ToFloat_ref(src, dst, size); | ||
return; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <arm_neon.h> | ||
#define FBGEMM_EXPORTS | ||
#include "fbgemm/FbgemmConvert.h" | ||
|
||
namespace fbgemm { | ||
|
||
namespace {

/**
 * @brief Shared body of FloatToFloat16_neon.
 *
 * @tparam kClip When true, every input is clamped to [-65504, 65504] before
 *               it is narrowed to fp16; when false the clamp is compiled out.
 * @param src Pointer to the fp32 input elements.
 * @param dst Pointer to the fp16 output elements.
 * @param size Number of elements to convert.
 */
template <bool kClip>
void FloatToFloat16NeonImpl(const float* src, float16* dst, size_t size) {
  constexpr float kFp16Max = 65504.f;
  const float32x4_t upper = vdupq_n_f32(kFp16Max);
  const float32x4_t lower = vdupq_n_f32(-kFp16Max);
  // Clamp into [lower, upper] when clipping is requested, otherwise pass the
  // vector through untouched.
  const auto prepare = [&](float32x4_t v) -> float32x4_t {
    return kClip ? vmaxq_f32(vminq_f32(v, upper), lower) : v;
  };
  // The NEON store intrinsics take __fp16 pointers; float16 is only used as
  // 16-bit storage here.
  __fp16* const out = reinterpret_cast<__fp16*>(dst);

  size_t i = 0;
  // Main loop: four 128-bit vectors (16 elements) per iteration. The bound is
  // inclusive so a final, exactly-full chunk is still vectorized.
  for (; i + 16 <= size; i += 16) {
    const float32x4_t a = prepare(vld1q_f32(src + i));
    const float32x4_t b = prepare(vld1q_f32(src + i + 4));
    const float32x4_t c = prepare(vld1q_f32(src + i + 8));
    const float32x4_t d = prepare(vld1q_f32(src + i + 12));
    const float16x4_t ha = vcvt_f16_f32(a);
    const float16x4_t hb = vcvt_f16_f32(b);
    const float16x4_t hc = vcvt_f16_f32(c);
    const float16x4_t hd = vcvt_f16_f32(d);
    vst1_f16(out + i, ha);
    vst1_f16(out + i + 4, hb);
    vst1_f16(out + i + 8, hc);
    vst1_f16(out + i + 12, hd);
  }
  // At most 15 elements remain: peel one 8-wide and one 4-wide step.
  if (i + 8 <= size) {
    const float32x4_t a = prepare(vld1q_f32(src + i));
    const float32x4_t b = prepare(vld1q_f32(src + i + 4));
    vst1_f16(out + i, vcvt_f16_f32(a));
    vst1_f16(out + i + 4, vcvt_f16_f32(b));
    i += 8;
  }
  if (i + 4 <= size) {
    const float32x4_t a = prepare(vld1q_f32(src + i));
    vst1_f16(out + i, vcvt_f16_f32(a));
    i += 4;
  }
  // Fewer than four elements are left; hand them to the scalar reference
  // with the same clipping setting.
  FloatToFloat16_ref(src + i, dst + i, size - i, kClip);
}

} // namespace

/**
 * @brief NEON implementation to convert fp32 numbers to fp16 numbers.
 *
 * @param src Pointer to the fp32 input elements.
 * @param dst Pointer to the fp16 output elements.
 * @param size Number of elements to convert.
 * @param do_clip When true, each input is clamped to [-65504, 65504] before
 *                it is converted.
 */
void FloatToFloat16_neon(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip) {
  // Resolve the clipping choice once, outside the vector loops.
  if (do_clip) {
    FloatToFloat16NeonImpl<true>(src, dst, size);
  } else {
    FloatToFloat16NeonImpl<false>(src, dst, size);
  }
}
|
||
/**
 * @brief NEON implementation to convert fp16 numbers to fp32 numbers.
 *
 * @param src Pointer to the fp16 input elements.
 * @param dst Pointer to the fp32 output elements.
 * @param size Number of elements to convert.
 */
void Float16ToFloat_neon(const float16* src, float* dst, size_t size) {
  // The NEON load intrinsics take __fp16 pointers; keep the input const.
  const __fp16* const in = reinterpret_cast<const __fp16*>(src);

  size_t i = 0;
  // Main loop: four 64-bit fp16 vectors (16 elements) per iteration. The
  // bound is inclusive so a final, exactly-full chunk is still vectorized.
  for (; i + 16 <= size; i += 16) {
    const float16x4_t h0 = vld1_f16(in + i);
    const float16x4_t h1 = vld1_f16(in + i + 4);
    const float16x4_t h2 = vld1_f16(in + i + 8);
    const float16x4_t h3 = vld1_f16(in + i + 12);
    const float32x4_t f0 = vcvt_f32_f16(h0);
    const float32x4_t f1 = vcvt_f32_f16(h1);
    const float32x4_t f2 = vcvt_f32_f16(h2);
    const float32x4_t f3 = vcvt_f32_f16(h3);
    vst1q_f32(dst + i, f0);
    vst1q_f32(dst + i + 4, f1);
    vst1q_f32(dst + i + 8, f2);
    vst1q_f32(dst + i + 12, f3);
  }
  // At most 15 elements remain: peel one 8-wide and one 4-wide step.
  if (i + 8 <= size) {
    const float16x4_t h0 = vld1_f16(in + i);
    const float16x4_t h1 = vld1_f16(in + i + 4);
    vst1q_f32(dst + i, vcvt_f32_f16(h0));
    vst1q_f32(dst + i + 4, vcvt_f32_f16(h1));
    i += 8;
  }
  if (i + 4 <= size) {
    vst1q_f32(dst + i, vcvt_f32_f16(vld1_f16(in + i)));
    i += 4;
  }
  // Fewer than four elements are left; finish with the scalar reference.
  Float16ToFloat_ref(src + i, dst + i, size - i);
}
|
||
} // namespace fbgemm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <arm_sve.h> | ||
#define FBGEMM_EXPORTS | ||
#include "fbgemm/FbgemmConvert.h" | ||
|
||
namespace fbgemm { | ||
|
||
namespace {

/**
 * @brief Optionally clamps an fp32 vector to [-65504, 65504].
 *
 * @tparam kClip When false the input is returned unchanged.
 * @param pg Governing predicate for the min/max operations.
 * @param v Vector of fp32 values.
 */
template <bool kClip>
inline svfloat32_t ClampToHalfRange_sve(svbool_t pg, svfloat32_t v) {
  constexpr float kFp16Max = 65504.f;
  if (kClip) {
    v = svmax_n_f32_x(pg, svmin_n_f32_x(pg, v, kFp16Max), -kFp16Max);
  }
  return v;
}

/**
 * @brief Shared body of FloatToFloat16_sve.
 *
 * @tparam kClip When true, inputs are clamped before they are narrowed.
 * @param src Pointer to the fp32 input elements.
 * @param dst Pointer to the fp16 output elements.
 * @param size Number of elements to convert.
 */
template <bool kClip>
void FloatToFloat16SveImpl(const float* src, float16* dst, size_t size) {
  // Number of 32-bit lanes per vector; kept unsigned so the loop bounds are
  // compared in size_t without a signed/unsigned mix.
  const size_t lanes = svcntw();
  svbool_t p_32 = svptrue_b32();
  svbool_t p_16 = svptrue_b16();
  // First `lanes` fp16 elements active, the rest inactive: used to store a
  // single vector's worth of converted values.
  svbool_t p_16_half = svuzp1_b16(p_16, svpfalse());
  __fp16* const out = reinterpret_cast<__fp16*>(dst);

  size_t i = 0;
  // Main loop: two fp32 vectors in, one full fp16 vector out. The bound is
  // inclusive so a final, exactly-full chunk is still vectorized.
  for (; i + 2 * lanes <= size; i += 2 * lanes) {
    svfloat32_t lo = ClampToHalfRange_sve<kClip>(p_32, svld1_f32(p_32, src + i));
    svfloat32_t hi =
        ClampToHalfRange_sve<kClip>(p_32, svld1_f32(p_32, src + i + lanes));
    svfloat16_t h_lo = svcvt_f16_f32_x(p_32, lo);
    svfloat16_t h_hi = svcvt_f16_f32_x(p_32, hi);
    // The narrowing convert leaves its results in the even-numbered fp16
    // lanes; uzp1 gathers those lanes from both vectors into one.
    svst1_f16(p_16, out + i, svuzp1_f16(h_lo, h_hi));
  }
  // Fewer than two vectors remain: convert one more full fp32 vector if
  // there is room.
  if (i + lanes <= size) {
    svfloat32_t v = ClampToHalfRange_sve<kClip>(p_32, svld1_f32(p_32, src + i));
    svfloat16_t h = svcvt_f16_f32_x(p_32, v);
    svst1_f16(p_16_half, out + i, svuzp1_f16(h, h));
    i += lanes;
  }
  // Less than one vector is left; hand it to the scalar reference with the
  // same clipping setting.
  FloatToFloat16_ref(src + i, dst + i, size - i, kClip);
}

} // namespace

/**
 * @brief SVE implementation to convert fp32 numbers to fp16 numbers.
 *
 * @param src Pointer to the fp32 input elements.
 * @param dst Pointer to the fp16 output elements.
 * @param size Number of elements to convert.
 * @param do_clip When true, each input is clamped to [-65504, 65504] before
 *                it is converted.
 */
void FloatToFloat16_sve(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip) {
  // Resolve the clipping choice once, outside the vector loops.
  if (do_clip) {
    FloatToFloat16SveImpl<true>(src, dst, size);
  } else {
    FloatToFloat16SveImpl<false>(src, dst, size);
  }
}
|
||
/**
 * @brief SVE implementation to convert fp16 numbers to fp32 numbers.
 *
 * @param src Pointer to the fp16 input elements.
 * @param dst Pointer to the fp32 output elements.
 * @param size Number of elements to convert.
 */
void Float16ToFloat_sve(const float16* src, float* dst, size_t size) {
  // Number of 32-bit lanes per vector; kept unsigned so the loop bounds are
  // compared in size_t without a signed/unsigned mix.
  const size_t lanes = svcntw();
  svbool_t p_32 = svptrue_b32();
  svbool_t p_16 = svptrue_b16();
  // First `lanes` fp16 elements active, the rest inactive: used to load a
  // single fp32 vector's worth of inputs.
  svbool_t p_16_half = svuzp1_b16(p_16, svpfalse());
  // The SVE load intrinsics take __fp16 pointers; keep the input const.
  const __fp16* const in = reinterpret_cast<const __fp16*>(src);

  size_t i = 0;
  // Main loop: one full fp16 vector in, two fp32 vectors out. The bound is
  // inclusive so a final, exactly-full chunk is still vectorized.
  for (; i + 2 * lanes <= size; i += 2 * lanes) {
    svfloat16_t packed = svld1_f16(p_16, in + i);
    // Zipping the vector with itself places each source element in the
    // even-numbered fp16 lane that the widening convert reads from.
    svfloat16_t lo = svzip1_f16(packed, packed);
    svfloat16_t hi = svzip2_f16(packed, packed);
    svfloat32_t f_lo = svcvt_f32_f16_x(p_32, lo);
    svfloat32_t f_hi = svcvt_f32_f16_x(p_32, hi);
    svst1_f32(p_32, dst + i, f_lo);
    svst1_f32(p_32, dst + i + lanes, f_hi);
  }
  // Fewer than two fp32 vectors remain: convert one more if there is room.
  if (i + lanes <= size) {
    svfloat16_t packed = svld1_f16(p_16_half, in + i);
    svfloat16_t widened = svzip1_f16(packed, packed);
    svst1_f32(p_32, dst + i, svcvt_f32_f16_x(p_32, widened));
    i += lanes;
  }
  // Less than one vector is left; finish with the scalar reference.
  Float16ToFloat_ref(src + i, dst + i, size - i);
}
|
||
} // namespace fbgemm |