From 4b91ef2f3a14519bf1c0ac25b25e1b95cbb35329 Mon Sep 17 00:00:00 2001 From: Arnaud De-Mattia Date: Mon, 4 Oct 2021 15:54:21 +0200 Subject: [PATCH] check for __FMA__ in case of SSE --- theory/DD/countpairs_kernels.c.src | 2 +- theory/DDrppi/countpairs_rp_pi_kernels.c.src | 2 +- theory/DDsmu/countpairs_s_mu_kernels.c.src | 2 +- theory/wp/wp_kernels.c.src | 2 +- theory/xi/xi_kernels.c.src | 2 +- utils/avx_calls.h | 8 ++++---- utils/sse_calls.h | 17 +++++++++++------ 7 files changed, 20 insertions(+), 15 deletions(-) diff --git a/theory/DD/countpairs_kernels.c.src b/theory/DD/countpairs_kernels.c.src index 150e91c2..237a5fc1 100644 --- a/theory/DD/countpairs_kernels.c.src +++ b/theory/DD/countpairs_kernels.c.src @@ -49,7 +49,7 @@ static inline int countpairs_avx512_intrinsics_DOUBLE(const int64_t N0, DOUBLE * } const int32_t need_rpavg = src_rpavg != NULL; const int32_t need_weightavg = src_weightavg != NULL; - const DOUBLE sqr_rpmin=rpmin*rpmin, sqr_rpmax=rpmax*rpmax; + const DOUBLE sqr_rpmax=rpmax*rpmax; AVX512_FLOATS m_inv_rpstep = AVX512_SETZERO_FLOAT(); AVX512_FLOATS m_rpmin_invstep = AVX512_SETZERO_FLOAT(); if (bin_type == BIN_LIN) { diff --git a/theory/DDrppi/countpairs_rp_pi_kernels.c.src b/theory/DDrppi/countpairs_rp_pi_kernels.c.src index b8f7e53e..8c98742f 100644 --- a/theory/DDrppi/countpairs_rp_pi_kernels.c.src +++ b/theory/DDrppi/countpairs_rp_pi_kernels.c.src @@ -43,7 +43,7 @@ static inline int countpairs_rp_pi_avx512_intrinsics_DOUBLE(const int64_t N0, DO const int32_t need_rpavg = src_rpavg != NULL; const int32_t need_weightavg = src_weightavg != NULL; - const DOUBLE sqr_rpmin=rpmin*rpmin, sqr_rpmax=rpmax*rpmax; + const DOUBLE sqr_rpmax=rpmax*rpmax; AVX512_FLOATS m_inv_rpstep = AVX512_SETZERO_FLOAT(); AVX512_FLOATS m_rpmin_invstep = AVX512_SETZERO_FLOAT(); if (bin_type == BIN_LIN) { diff --git a/theory/DDsmu/countpairs_s_mu_kernels.c.src b/theory/DDsmu/countpairs_s_mu_kernels.c.src index 585e08de..a6e092e1 100644 --- a/theory/DDsmu/countpairs_s_mu_kernels.c.src +++ b/theory/DDsmu/countpairs_s_mu_kernels.c.src @@ -43,7 +43,7 @@ static inline int countpairs_s_mu_avx512_intrinsics_DOUBLE(const int64_t N0, DOU const int32_t need_savg = src_savg != NULL; const int32_t need_weightavg = src_weightavg != NULL; - const DOUBLE sqr_smin=smin*smin, sqr_smax=smax*smax; + const DOUBLE sqr_smax=smax*smax; AVX512_FLOATS m_inv_sstep = AVX512_SETZERO_FLOAT(); AVX512_FLOATS m_smin_invstep = AVX512_SETZERO_FLOAT(); if (bin_type == BIN_LIN) { diff --git a/theory/wp/wp_kernels.c.src b/theory/wp/wp_kernels.c.src index 9291ffe4..50c1ad55 100644 --- a/theory/wp/wp_kernels.c.src +++ b/theory/wp/wp_kernels.c.src @@ -48,7 +48,7 @@ static inline int wp_avx512_intrinsics_DOUBLE(DOUBLE *x0, DOUBLE *y0, DOUBLE *z0 } const int32_t need_rpavg = src_rpavg != NULL; const int32_t need_weightavg = src_weightavg != NULL; - const DOUBLE sqr_rpmin=rpmin*rpmin, sqr_rpmax=rpmax*rpmax; + const DOUBLE sqr_rpmax=rpmax*rpmax; AVX512_FLOATS m_inv_rpstep = AVX512_SETZERO_FLOAT(); AVX512_FLOATS m_rpmin_invstep = AVX512_SETZERO_FLOAT(); if (bin_type == BIN_LIN) { diff --git a/theory/xi/xi_kernels.c.src b/theory/xi/xi_kernels.c.src index ddd748a8..5ecdee07 100644 --- a/theory/xi/xi_kernels.c.src +++ b/theory/xi/xi_kernels.c.src @@ -50,7 +50,7 @@ static inline int xi_avx512_intrinsics_DOUBLE(DOUBLE *x1, DOUBLE *y1, DOUBLE *z1 } const int32_t need_rpavg = src_rpavg != NULL; const int32_t need_weightavg = src_weightavg != NULL; - const DOUBLE sqr_rmin=rmin*rmin, sqr_rmax=rmax*rmax; + const DOUBLE sqr_rmax=rmax*rmax; AVX512_FLOATS m_inv_rpstep = AVX512_SETZERO_FLOAT(); AVX512_FLOATS m_rpmin_invstep = AVX512_SETZERO_FLOAT(); if (bin_type == BIN_LIN) { diff --git a/utils/avx_calls.h b/utils/avx_calls.h index 830b93f0..ba3e202e 100644 --- a/utils/avx_calls.h +++ b/utils/avx_calls.h @@ -63,7 +63,7 @@ extern "C" { /* returns Z + XY*/ #define AVX_FMA_ADD_FLOATS(X,Y,Z) _mm256_fmadd_ps(X,Y,Z) -#define AVX_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm256_round_ps(_mm256_fmadd_ps(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) +#define AVX_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm256_round_ps(AVX_FMA_ADD_FLOATS(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) // X OP Y #define AVX_COMPARE_FLOATS(X,Y,OP) _mm256_cmp_ps(X,Y,OP) @@ -82,11 +82,11 @@ extern "C" { #ifdef __INTEL_COMPILER #define AVX_ARC_COSINE(X, order) _mm256_acos_ps(X) #else - //Other compilers do not have the vectorized arc-cosine +//Other compilers do not have the vectorized arc-cosine #define AVX_ARC_COSINE(X, order) inv_cosine_avx(X, order) #endif - //Max +//Max #define AVX_MAX_FLOATS(X,Y) _mm256_max_ps(X,Y) @@ -128,7 +128,7 @@ extern "C" { /* returns Z + XY*/ #define AVX_FMA_ADD_FLOATS(X,Y,Z) _mm256_fmadd_pd(X,Y,Z) -#define AVX_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm256_round_pd(_mm256_fmadd_pd(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) +#define AVX_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm256_round_pd(AVX_FMA_ADD_FLOATS(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) // X OP Y #define AVX_COMPARE_FLOATS(X,Y,OP) _mm256_cmp_pd(X,Y,OP) diff --git a/utils/sse_calls.h b/utils/sse_calls.h index b152f842..bf3246ee 100644 --- a/utils/sse_calls.h +++ b/utils/sse_calls.h @@ -61,10 +61,12 @@ extern "C" { #define SSE_ABS_FLOAT(X) _mm_max_ps(_mm_sub_ps(_mm_setzero_ps(), X), X) /* returns Z + XY*/ +#ifdef __FMA__ #define SSE_FMA_ADD_FLOATS(X,Y,Z) _mm_fmadd_ps(X,Y,Z) -//#define SSE_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm_fmadd_round_ss(X,Y,Z,_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) -#define SSE_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm_round_ps(_mm_fmadd_ps(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) - +#else +#define SSE_FMA_ADD_FLOATS(X,Y,Z) _mm_add_ps(_mm_mul_ps(X,Y),Z) +#endif +#define SSE_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm_round_ps(SSE_FMA_ADD_FLOATS(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) #ifdef __INTEL_COMPILER #define SSE_ARC_COSINE(X, order) _mm_acos_ps(X) @@ -122,9 +124,12 @@ extern "C" { #define SSE_ABS_FLOAT(X) _mm_max_pd(_mm_sub_pd(_mm_setzero_pd(), X), X) /* returns Z + XY*/ -#define SSE_FMA_ADD_FLOATS(X,Y,Z) _mm_fmadd_pd(X,Y,Z) -//#define SSE_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm_fmadd_round_sd(X,Y,Z,_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) -#define SSE_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm_round_pd(_mm_fmadd_pd(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) +#ifdef __FMA__ +#define SSE_FMA_ADD_FLOATS(X,Y,Z) _mm_fmadd_pd(X,Y,Z) +#else +#define SSE_FMA_ADD_FLOATS(X,Y,Z) _mm_add_pd(_mm_mul_pd(X,Y),Z) +#endif +#define SSE_FMA_ADD_TRUNCATE_FLOATS(X,Y,Z) _mm_round_pd(SSE_FMA_ADD_FLOATS(X,Y,Z),_MM_FROUND_TO_ZERO|_MM_FROUND_NO_EXC) #endif