Skip to content

Commit

Permalink
ENH: Add preliminary RVV support using neon2rvv (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksco authored May 8, 2024
1 parent 970d7e3 commit edee45e
Show file tree
Hide file tree
Showing 25 changed files with 129 additions and 42 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@
[submodule "numpy/_core/src/common/pythoncapi-compat"]
path = numpy/_core/src/common/pythoncapi-compat
url = https://github.com/python/pythoncapi-compat
[submodule "numpy/_core/src/neon2rvv"]
path = numpy/_core/src/neon2rvv
url = https://github.com/ksco/neon2rvv.git
28 changes: 22 additions & 6 deletions numpy/_core/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,7 @@ _umath_tests_mtargets = mod_features.multi_targets(
ASIMDHP, ASIMD, NEON,
VSX3, VSX2, VSX,
VXE, VX,
RVV,
],
baseline: CPU_BASELINE,
prefix: 'NPY_',
Expand Down Expand Up @@ -772,7 +773,8 @@ foreach gen_mtargets : [
AVX512_SKX, AVX2, XOP, SSE42, SSE2,
VSX2,
ASIMD, NEON,
VXE, VX
VXE, VX,
RVV,
]
],
]
Expand All @@ -790,7 +792,8 @@ foreach gen_mtargets : [
'src/multiarray',
'src/multiarray/stringdtype',
'src/npymath',
'src/umath'
'src/umath',
'src/neon2rvv',
]
)
if not is_variable('multiarray_umath_mtargets')
Expand Down Expand Up @@ -848,6 +851,7 @@ foreach gen_mtargets : [
'src/npymath',
'src/umath',
'src/highway',
'src/neon2rvv',
]
)
if not is_variable('multiarray_umath_mtargets')
Expand All @@ -874,6 +878,7 @@ foreach gen_mtargets : [
ASIMD, NEON,
VSX3, VSX2,
VXE, VX,
# RVV, triggered gcc 14.0.1 ICE
]
],
[
Expand All @@ -884,6 +889,7 @@ foreach gen_mtargets : [
NEON,
VSX4, VSX2,
VX,
RVV,
]
],
[
Expand All @@ -894,6 +900,7 @@ foreach gen_mtargets : [
VSX3, VSX2,
NEON,
VXE, VX,
RVV,
]
],
[
Expand All @@ -910,7 +917,7 @@ foreach gen_mtargets : [
AVX512_SKX, [AVX2, FMA3],
VSX4, VSX2,
NEON_VFPV4,
VXE, VX
VXE, VX,
]
],
[
Expand All @@ -921,6 +928,7 @@ foreach gen_mtargets : [
AVX512_SKX, AVX2, SSE2,
VSX2,
VX,
RVV,
]
],
[
Expand All @@ -931,6 +939,7 @@ foreach gen_mtargets : [
AVX512_SKX, AVX2, SSE2,
VSX2,
VXE, VX,
RVV,
]
],
[
Expand Down Expand Up @@ -962,7 +971,8 @@ foreach gen_mtargets : [
ASIMD, NEON,
AVX512_SKX, AVX2, SSE2,
VSX2,
VXE, VX
VXE, VX,
RVV,
]
],
[
Expand All @@ -972,7 +982,8 @@ foreach gen_mtargets : [
SSE41, SSE2,
VSX2,
ASIMD, NEON,
VXE, VX
VXE, VX,
RVV,
]
],
[
Expand All @@ -982,6 +993,7 @@ foreach gen_mtargets : [
SSE41, SSE2,
VSX2,
ASIMD, NEON,
RVV,
]
],
[
Expand All @@ -992,6 +1004,7 @@ foreach gen_mtargets : [
ASIMD, NEON,
VSX3, VSX2,
VXE, VX,
RVV,
]
],
[
Expand All @@ -1002,6 +1015,7 @@ foreach gen_mtargets : [
NEON,
VSX2,
VX,
RVV,
]
],
]
Expand All @@ -1018,7 +1032,8 @@ foreach gen_mtargets : [
'src/common',
'src/multiarray',
'src/npymath',
'src/umath'
'src/umath',
'src/neon2rvv',
]
)
if not is_variable('multiarray_umath_mtargets')
Expand Down Expand Up @@ -1214,6 +1229,7 @@ py.extension_module('_multiarray_umath',
'src/multiarray',
'src/npymath',
'src/umath',
'src/neon2rvv',
],
dependencies: [blas_dep],
link_with: [npymath_lib, multiarray_umath_mtargets.static_lib('_multiarray_umath_mtargets')] + highway_lib,
Expand Down
21 changes: 13 additions & 8 deletions numpy/_core/src/common/simd/intdiv.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@
#ifdef _MSC_VER
#include <intrin.h> // _BitScanReverse
#endif

#if defined(NPY_HAVE_RVV)
#define val __val
#endif

NPY_FINLINE unsigned npyv__bitscan_revnz_u32(npy_uint32 a)
{
assert(a > 0); // due to use __builtin_clz
Expand Down Expand Up @@ -212,7 +217,7 @@ NPY_FINLINE npyv_u8x3 npyv_divisor_u8(npy_uint8 d)
divisor.val[0] = npyv_setall_u8(m);
divisor.val[1] = npyv_setall_u8(sh1);
divisor.val[2] = npyv_setall_u8(sh2);
#elif defined(NPY_HAVE_NEON)
#elif defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[0] = npyv_setall_u8(m);
divisor.val[1] = npyv_reinterpret_u8_s8(npyv_setall_s8(-sh1));
divisor.val[2] = npyv_reinterpret_u8_s8(npyv_setall_s8(-sh2));
Expand Down Expand Up @@ -251,7 +256,7 @@ NPY_FINLINE npyv_s8x3 npyv_divisor_s8(npy_int8 d)
divisor.val[2] = npyv_setall_s8(d < 0 ? -1 : 0);
#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX)
divisor.val[1] = npyv_setall_s8(sh);
#elif defined(NPY_HAVE_NEON)
#elif defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[1] = npyv_setall_s8(-sh);
#else
#error "please initialize the shifting operand for the new architecture"
Expand Down Expand Up @@ -288,7 +293,7 @@ NPY_FINLINE npyv_u16x3 npyv_divisor_u16(npy_uint16 d)
#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX)
divisor.val[1] = npyv_setall_u16(sh1);
divisor.val[2] = npyv_setall_u16(sh2);
#elif defined(NPY_HAVE_NEON)
#elif defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[1] = npyv_reinterpret_u16_s16(npyv_setall_s16(-sh1));
divisor.val[2] = npyv_reinterpret_u16_s16(npyv_setall_s16(-sh2));
#else
Expand Down Expand Up @@ -319,7 +324,7 @@ NPY_FINLINE npyv_s16x3 npyv_divisor_s16(npy_int16 d)
divisor.val[1] = npyv_set_s16(sh);
#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX)
divisor.val[1] = npyv_setall_s16(sh);
#elif defined(NPY_HAVE_NEON)
#elif defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[1] = npyv_setall_s16(-sh);
#else
#error "please initialize the shifting operand for the new architecture"
Expand Down Expand Up @@ -355,7 +360,7 @@ NPY_FINLINE npyv_u32x3 npyv_divisor_u32(npy_uint32 d)
#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX)
divisor.val[1] = npyv_setall_u32(sh1);
divisor.val[2] = npyv_setall_u32(sh2);
#elif defined(NPY_HAVE_NEON)
#elif defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[1] = npyv_reinterpret_u32_s32(npyv_setall_s32(-sh1));
divisor.val[2] = npyv_reinterpret_u32_s32(npyv_setall_s32(-sh2));
#else
Expand Down Expand Up @@ -391,7 +396,7 @@ NPY_FINLINE npyv_s32x3 npyv_divisor_s32(npy_int32 d)
divisor.val[1] = npyv_set_s32(sh);
#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX)
divisor.val[1] = npyv_setall_s32(sh);
#elif defined(NPY_HAVE_NEON)
#elif defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[1] = npyv_setall_s32(-sh);
#else
#error "please initialize the shifting operand for the new architecture"
Expand All @@ -402,7 +407,7 @@ NPY_FINLINE npyv_s32x3 npyv_divisor_s32(npy_int32 d)
NPY_FINLINE npyv_u64x3 npyv_divisor_u64(npy_uint64 d)
{
npyv_u64x3 divisor;
#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) || defined(NPY_HAVE_NEON)
#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) || defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[0] = npyv_setall_u64(d);
#else
npy_uint64 l, l2, sh1, sh2, m;
Expand Down Expand Up @@ -437,7 +442,7 @@ NPY_FINLINE npyv_u64x3 npyv_divisor_u64(npy_uint64 d)
NPY_FINLINE npyv_s64x3 npyv_divisor_s64(npy_int64 d)
{
npyv_s64x3 divisor;
#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) || defined(NPY_HAVE_NEON)
#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) || defined(NPY_HAVE_NEON) || defined(NPY_HAVE_RVV)
divisor.val[0] = npyv_setall_s64(d);
divisor.val[1] = npyv_cvt_s64_b64(
npyv_cmpeq_s64(npyv_setall_s64(-1), divisor.val[0])
Expand Down
16 changes: 8 additions & 8 deletions numpy/_core/src/common/simd/neon/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
const uint8x8_t mulc_lo = vget_low_u8(divisor.val[0]);
// high part of unsigned multiplication
uint16x8_t mull_lo = vmull_u8(vget_low_u8(a), mulc_lo);
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
uint16x8_t mull_hi = vmull_high_u8(a, divisor.val[0]);
// get the high unsigned bytes
uint8x16_t mulhi = vuzp2q_u8(vreinterpretq_u8_u16(mull_lo), vreinterpretq_u8_u16(mull_hi));
Expand All @@ -92,7 +92,7 @@ NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor)
const int8x8_t mulc_lo = vget_low_s8(divisor.val[0]);
// high part of signed multiplication
int16x8_t mull_lo = vmull_s8(vget_low_s8(a), mulc_lo);
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
int16x8_t mull_hi = vmull_high_s8(a, divisor.val[0]);
// get the high unsigned bytes
int8x16_t mulhi = vuzp2q_s8(vreinterpretq_s8_s16(mull_lo), vreinterpretq_s8_s16(mull_hi));
Expand All @@ -114,7 +114,7 @@ NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor)
const uint16x4_t mulc_lo = vget_low_u16(divisor.val[0]);
// high part of unsigned multiplication
uint32x4_t mull_lo = vmull_u16(vget_low_u16(a), mulc_lo);
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
uint32x4_t mull_hi = vmull_high_u16(a, divisor.val[0]);
// get the high unsigned bytes
uint16x8_t mulhi = vuzp2q_u16(vreinterpretq_u16_u32(mull_lo), vreinterpretq_u16_u32(mull_hi));
Expand All @@ -136,7 +136,7 @@ NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor)
const int16x4_t mulc_lo = vget_low_s16(divisor.val[0]);
// high part of signed multiplication
int32x4_t mull_lo = vmull_s16(vget_low_s16(a), mulc_lo);
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
int32x4_t mull_hi = vmull_high_s16(a, divisor.val[0]);
// get the high unsigned bytes
int16x8_t mulhi = vuzp2q_s16(vreinterpretq_s16_s32(mull_lo), vreinterpretq_s16_s32(mull_hi));
Expand All @@ -158,7 +158,7 @@ NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor)
const uint32x2_t mulc_lo = vget_low_u32(divisor.val[0]);
// high part of unsigned multiplication
uint64x2_t mull_lo = vmull_u32(vget_low_u32(a), mulc_lo);
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
uint64x2_t mull_hi = vmull_high_u32(a, divisor.val[0]);
// get the high unsigned bytes
uint32x4_t mulhi = vuzp2q_u32(vreinterpretq_u32_u64(mull_lo), vreinterpretq_u32_u64(mull_hi));
Expand All @@ -180,7 +180,7 @@ NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor)
const int32x2_t mulc_lo = vget_low_s32(divisor.val[0]);
// high part of signed multiplication
int64x2_t mull_lo = vmull_s32(vget_low_s32(a), mulc_lo);
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
int64x2_t mull_hi = vmull_high_s32(a, divisor.val[0]);
// get the high unsigned bytes
int32x4_t mulhi = vuzp2q_s32(vreinterpretq_s32_s64(mull_lo), vreinterpretq_s32_s64(mull_hi));
Expand Down Expand Up @@ -238,7 +238,7 @@ NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor)
/***************************
* FUSED F32
***************************/
#ifdef NPY_HAVE_NEON_VFPV4 // FMA
#if defined (NPY_HAVE_NEON_VFPV4) || defined(NPY_HAVE_RVV) // FMA
// multiply and add, a*b + c
NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{ return vfmaq_f32(c, a, b); }
Expand Down Expand Up @@ -321,7 +321,7 @@ NPY_FINLINE npyv_f32 npyv_muladdsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
#endif

// expand the source vector and performs sum reduce
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
#define npyv_sumup_u8 vaddlvq_u8
#define npyv_sumup_u16 vaddlvq_u16
#else
Expand Down
10 changes: 9 additions & 1 deletion numpy/_core/src/common/simd/neon/conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#ifndef _NPY_SIMD_NEON_CVT_H
#define _NPY_SIMD_NEON_CVT_H

#if defined(NPY_HAVE_RVV)
#define val __val
#endif

// convert boolean vectors to integer vectors
#define npyv_cvt_u8_b8(A) A
#define npyv_cvt_s8_b8 vreinterpretq_s8_u8
Expand Down Expand Up @@ -38,6 +42,10 @@ NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a)
const npyv_u8 byteOrder = {0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15};
npyv_u8 v0 = vqtbl1q_u8(seq_scale, byteOrder);
return vaddlvq_u16(vreinterpretq_u16_u8(v0));
#elif defined(NPY_SIMD_F64)
npy_uint8 sumlo = vaddv_u8(vget_low_u8(seq_scale));
npy_uint8 sumhi = vaddv_u8(vget_high_u8(seq_scale));
return sumlo + ((int)sumhi << 8);
#else
npyv_u64 sumh = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(seq_scale)));
return vgetq_lane_u64(sumh, 0) + ((int)vgetq_lane_u64(sumh, 1) << 8);
Expand Down Expand Up @@ -128,7 +136,7 @@ npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d,
}

// round to nearest integer
#if NPY_SIMD_F64
#if NPY_SIMD_F64 && !defined(__riscv)
#define npyv_round_s32_f32 vcvtnq_s32_f32
NPY_FINLINE npyv_s32 npyv_round_s32_f64(npyv_f64 a, npyv_f64 b)
{
Expand Down
Loading

0 comments on commit edee45e

Please sign in to comment.