Skip to content

Commit

Permalink
Add NEON and SVE implementations for Float16 conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
annop-w committed Nov 28, 2024
1 parent 357b54c commit 6d4228c
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 0 deletions.
35 changes: 35 additions & 0 deletions include/fbgemm/FbgemmConvert.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* Copyright 2024 Arm Limited and/or its affiliates <[email protected]>
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -135,6 +136,26 @@ FBGEMM_API void FloatToFloat16_avx512(
size_t size,
bool do_clip = false);

/**
* @brief SVE implementation to convert fp32 numbers to fp16 numbers.
*
*/
FBGEMM_API void FloatToFloat16_sve(
const float* src,
float16* dst,
size_t size,
bool do_clip = false);

/**
* @brief NEON implementation to convert fp32 numbers to fp16 numbers.
*
*/
FBGEMM_API void FloatToFloat16_neon(
const float* src,
float16* dst,
size_t size,
bool do_clip = false);

/**
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
*
Expand All @@ -149,6 +170,20 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
FBGEMM_API void
Float16ToFloat_avx512(const float16* src, float* dst, size_t size);

/**
* @brief SVE implementation to convert fp16 numbers to fp32 numbers.
*
*/
FBGEMM_API void
Float16ToFloat_sve(const float16* src, float* dst, size_t size);

/**
* @brief NEON implementation to convert fp16 numbers to fp32 numbers.
*
*/
FBGEMM_API void
Float16ToFloat_neon(const float16* src, float* dst, size_t size);

/**
* @brief Transform all entries in a matrix from fp32 to float16 and back to
* fp32.
Expand Down
15 changes: 15 additions & 0 deletions src/FbgemmFloat16Convert.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* Copyright 2024 Arm Limited and/or its affiliates <[email protected]>
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -39,10 +40,17 @@ void FloatToFloat16_simd(
bool do_clip) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#ifdef __aarch64__
if (fbgemmHasArmSveSupport()) {
FloatToFloat16_sve(src, dst, size, do_clip);
} else if (fbgemmHasArmNeonSupport()) {
FloatToFloat16_neon(src, dst, size, do_clip);
#else
if (fbgemmHasAvx512Support()) {
FloatToFloat16_avx512(src, dst, size, do_clip);
} else if (fbgemmHasAvx2Support()) {
FloatToFloat16_avx2(src, dst, size, do_clip);
#endif
} else {
FloatToFloat16_ref(src, dst, size, do_clip);
return;
Expand All @@ -55,10 +63,17 @@ void FloatToFloat16_simd(
void Float16ToFloat_simd(const float16* src, float* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#ifdef __aarch64__
if (fbgemmHasArmSveSupport()) {
Float16ToFloat_sve(src, dst, size);
} else if (fbgemmHasArmNeonSupport()) {
Float16ToFloat_neon(src, dst, size);
#else
if (fbgemmHasAvx512Support()) {
Float16ToFloat_avx512(src, dst, size);
} else if (fbgemmHasAvx2Support()) {
Float16ToFloat_avx2(src, dst, size);
#endif
} else {
Float16ToFloat_ref(src, dst, size);
return;
Expand Down
122 changes: 122 additions & 0 deletions src/FbgemmFloat16ConvertNeon.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate <[email protected]>
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <arm_neon.h>
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmConvert.h"

namespace fbgemm {

void FloatToFloat16_neon(
const float* src,
float16* dst,
size_t size,
bool do_clip) {
if (do_clip) {
constexpr float FP16_MAX = 65504.f;
auto vpos = vdupq_n_f32(FP16_MAX);
auto vneg = vdupq_n_f32(-FP16_MAX);
size_t i = 0;
for (; i + 16 < size; i += 16) {
auto f32_vec1 = vld1q_f32(src + i);
auto f32_vec2 = vld1q_f32(src + i + 4);
auto f32_vec3 = vld1q_f32(src + i + 8);
auto f32_vec4 = vld1q_f32(src + i + 12);
f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg);
f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg);
f32_vec3 = vmaxq_f32(vminq_f32(f32_vec3, vpos), vneg);
f32_vec4 = vmaxq_f32(vminq_f32(f32_vec4, vpos), vneg);
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
auto f16_vec3 = vcvt_f16_f32(f32_vec3);
auto f16_vec4 = vcvt_f16_f32(f32_vec4);
vst1_f16((__fp16*)dst + i, f16_vec1);
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
vst1_f16((__fp16*)dst + i + 8, f16_vec3);
vst1_f16((__fp16*)dst + i + 12, f16_vec4);
}
for (; i + 8 < size; i += 8) {
auto f32_vec1 = vld1q_f32(src + i);
auto f32_vec2 = vld1q_f32(src + i + 4);
f32_vec1 = vmaxq_f32(vminq_f32(f32_vec1, vpos), vneg);
f32_vec2 = vmaxq_f32(vminq_f32(f32_vec2, vpos), vneg);
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
vst1_f16((__fp16*)dst + i, f16_vec1);
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
}
for (; i + 4 < size; i += 4) {
auto f32_vec = vld1q_f32(src + i);
f32_vec = vmaxq_f32(vminq_f32(f32_vec, vpos), vneg);
auto f16_vec = vcvt_f16_f32(f32_vec);
vst1_f16((__fp16*)dst + i, f16_vec);
}
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
} else {
size_t i = 0;
for (; i + 16 < size; i += 16) {
auto f32_vec1 = vld1q_f32(src + i);
auto f32_vec2 = vld1q_f32(src + i + 4);
auto f32_vec3 = vld1q_f32(src + i + 8);
auto f32_vec4 = vld1q_f32(src + i + 12);
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
auto f16_vec3 = vcvt_f16_f32(f32_vec3);
auto f16_vec4 = vcvt_f16_f32(f32_vec4);
vst1_f16((__fp16*)dst + i, f16_vec1);
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
vst1_f16((__fp16*)dst + i + 8, f16_vec3);
vst1_f16((__fp16*)dst + i + 12, f16_vec4);
}
for (; i + 8 < size; i += 8) {
auto f32_vec1 = vld1q_f32(src + i);
auto f32_vec2 = vld1q_f32(src + i + 4);
auto f16_vec1 = vcvt_f16_f32(f32_vec1);
auto f16_vec2 = vcvt_f16_f32(f32_vec2);
vst1_f16((__fp16*)dst + i, f16_vec1);
vst1_f16((__fp16*)dst + i + 4, f16_vec2);
}
for (; i + 4 < size; i += 4) {
auto f32_vec = vld1q_f32(src + i);
auto f16_vec = vcvt_f16_f32(f32_vec);
vst1_f16((__fp16*)dst + i, f16_vec);
}
FloatToFloat16_ref(src + i, dst + i, size - i);
}
}

void Float16ToFloat_neon(const float16* src, float* dst, size_t size) {
size_t i = 0;
for (; i + 16 < size; i += 16) {
auto f16_vec1 = vld1_f16((__fp16*)src + i);
auto f16_vec2 = vld1_f16((__fp16*)src + i + 4);
auto f16_vec3 = vld1_f16((__fp16*)src + i + 8);
auto f16_vec4 = vld1_f16((__fp16*)src + i + 12);
auto f32_vec1 = vcvt_f32_f16(f16_vec1);
auto f32_vec2 = vcvt_f32_f16(f16_vec2);
auto f32_vec3 = vcvt_f32_f16(f16_vec3);
auto f32_vec4 = vcvt_f32_f16(f16_vec4);
vst1q_f32(dst + i, f32_vec1);
vst1q_f32(dst + i + 4, f32_vec2);
vst1q_f32(dst + i + 8, f32_vec3);
vst1q_f32(dst + i + 12, f32_vec4);
}
for (; i + 8 < size; i += 8) {
auto f16_vec1 = vld1_f16((__fp16*)src + i);
auto f16_vec2 = vld1_f16((__fp16*)src + i + 4);
auto f32_vec1 = vcvt_f32_f16(f16_vec1);
auto f32_vec2 = vcvt_f32_f16(f16_vec2);
vst1q_f32(dst + i, f32_vec1);
vst1q_f32(dst + i + 4, f32_vec2);
}
for (; i + 4 < size; i += 4) {
auto f16_vec = vld1_f16((__fp16*)src + i);
auto f32_vec = vcvt_f32_f16(f16_vec);
vst1q_f32(dst + i, f32_vec);
}
Float16ToFloat_ref(src + i, dst + i, size - i);
}

} // namespace fbgemm
99 changes: 99 additions & 0 deletions src/FbgemmFloat16ConvertSve.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliate <[email protected]>
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <arm_sve.h>
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmConvert.h"

namespace fbgemm {

void FloatToFloat16_sve(
const float* src,
float16* dst,
size_t size,
bool do_clip) {
if (do_clip) {
constexpr float FP16_MAX = 65504.f;
size_t i = 0;
int lanes = svcntw();
auto p_32 = svptrue_b32();
auto p_16 = svptrue_b16();
auto pfalse = svpfalse();
auto p_16_half = svuzp1_b16(p_16, pfalse);
while (i + 2 * lanes < size) {
auto f32_vec1 = svld1_f32(p_32, src + i);
auto f32_vec2 = svld1_f32(p_32, src + i + lanes);
f32_vec1 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec1, FP16_MAX), -FP16_MAX);
f32_vec2 = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec2, FP16_MAX), -FP16_MAX);
auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1);
auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2);
auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2);
svst1_f16(p_16, (__fp16*)dst + i, f16_vec);
i += 2 * lanes;
}
while (i + lanes < size) {
auto f32_vec = svld1_f32(p_32, src + i);
f32_vec = svmax_n_f32_x(p_32, svmin_n_f32_x(p_32, f32_vec, FP16_MAX), -FP16_MAX);
auto f16_vec = svcvt_f16_f32_x(p_16, f32_vec);
f16_vec = svuzp1_f16(f16_vec, f16_vec);
svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec);
i += lanes;
}
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
} else {
size_t i = 0;
int lanes = svcntw();
auto p_32 = svptrue_b32();
auto p_16 = svptrue_b16();
auto pfalse = svpfalse();
auto p_16_half = svuzp1_b16(p_16, pfalse);
while (i + 2 * lanes < size) {
auto f32_vec1 = svld1_f32(p_32, src + i);
auto f32_vec2 = svld1_f32(p_32, src + i + lanes);
auto f16_vec1 = svcvt_f16_f32_x(p_32, f32_vec1);
auto f16_vec2 = svcvt_f16_f32_x(p_32, f32_vec2);
auto f16_vec = svuzp1_f16(f16_vec1, f16_vec2);
svst1_f16(p_16, (__fp16*)dst + i, f16_vec);
i += 2 * lanes;
}
while (i + lanes < size) {
auto f32_vec = svld1_f32(p_32, src + i);
auto f16_vec = svcvt_f16_f32_x(p_32, f32_vec);
f16_vec = svuzp1_f16(f16_vec, f16_vec);
svst1_f16(p_16_half, (__fp16*)dst + i, f16_vec);
i += lanes;
}
FloatToFloat16_ref(src + i, dst + i, size - i);
}
}

void Float16ToFloat_sve(const float16* src, float* dst, size_t size) {
size_t i = 0;
int lanes = svcntw();
auto p_32 = svptrue_b32();
auto p_16 = svptrue_b16();
auto pfalse = svpfalse();
auto p_16_half = svuzp1_b16(p_16, pfalse);
while (i + 2 * lanes < size) {
auto f16_vec = svld1_f16(p_16, (__fp16*)src + i);
auto f16_vec1 = svzip1(f16_vec, f16_vec);
auto f16_vec2 = svzip2(f16_vec, f16_vec);
auto f32_vec1 = svcvt_f32_f16_x(p_16, f16_vec1);
auto f32_vec2 = svcvt_f32_f16_x(p_16, f16_vec2);
svst1_f32(p_32, dst + i, f32_vec1);
svst1_f32(p_32, dst + i + lanes, f32_vec2);
i += 2 * lanes;
}
while (i + lanes < size) {
auto f16_vec = svld1_f16(p_16_half, (__fp16*)src + i);
f16_vec = svzip1_f16(f16_vec, f16_vec);
auto f32_vec = svcvt_f32_f16_x(p_32, f16_vec);
svst1_f32(p_32, dst + i, f32_vec);
i += lanes;
}
Float16ToFloat_ref(src + i, dst + i, size - i);
}

} // namespace fbgemm

0 comments on commit 6d4228c

Please sign in to comment.