Skip to content

Commit

Permalink
refactor: option to turn on sse2 optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang committed Nov 17, 2023
1 parent 26b74fd commit 6883221
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ option(vw_BUILD_NET_FRAMEWORK "Build .NET Framework targets" OFF)
option(VW_USE_ASAN "Compile with AddressSanitizer" OFF)
option(VW_USE_UBSAN "Compile with UndefinedBehaviorSanitizer" OFF)
option(VW_BUILD_WASM "Add WASM target" OFF)
option(SSE2_OPT "Add SSE2 optimization" OFF)

if(VW_USE_ASAN)
add_compile_definitions(VW_USE_ASAN)
Expand Down
48 changes: 47 additions & 1 deletion vowpalwabbit/core/src/reductions/gd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,52 @@ void sync_weights(VW::workspace& all)
all.sd->contraction = 1.;
}

VW_WARNING_STATE_PUSH
VW_WARNING_DISABLE_UNUSED_FUNCTION
inline float quake_inv_sqrt(float x)
{
// Carmack/Quake/SGI fast method:
float xhalf = 0.5f * x;
static_assert(sizeof(int) == sizeof(float), "Floats and ints are converted between, they must be the same size.");
int i = reinterpret_cast<int&>(x); // store floating-point bits in integer
i = 0x5f3759d5 - (i >> 1); // initial guess for Newton's method
x = reinterpret_cast<float&>(i); // convert new bits into float
x = x * (1.5f - xhalf * x * x); // One round of Newton's method
return x;
}
VW_WARNING_STATE_POP

static inline float inv_sqrt(float x)
{
#if !defined(SSE2_OPT)
return 1.f / std::sqrt(x);
#endif
#if !defined(VW_NO_INLINE_SIMD)
# if defined(__ARM_NEON__)
// Propagate into vector
float32x2_t v1 = vdup_n_f32(x);
// Estimate
float32x2_t e1 = vrsqrte_f32(v1);
// N-R iteration 1
float32x2_t e2 = vmul_f32(e1, vrsqrts_f32(v1, vmul_f32(e1, e1)));
// N-R iteration 2
float32x2_t e3 = vmul_f32(e2, vrsqrts_f32(v1, vmul_f32(e2, e2)));
// Extract result
return vget_lane_f32(e3, 0);
# elif defined(__SSE2__)
__m128 eta = _mm_load_ss(&x);
eta = _mm_rsqrt_ss(eta);
_mm_store_ss(&x, eta);
# else
x = quake_inv_sqrt(x);
# endif
#else
x = quake_inv_sqrt(x);
#endif

return x;
}

VW_WARNING_STATE_PUSH
VW_WARNING_DISABLE_COND_CONST_EXPR
template <bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
Expand Down Expand Up @@ -580,7 +626,7 @@ inline float compute_rate_decay(power_data& s, float& fw)
float rate_decay = 1.f;
if (adaptive)
{
if (sqrt_rate) { rate_decay = 1.0f / std::sqrt(w[adaptive]); }
if (sqrt_rate) { rate_decay = inv_sqrt(w[adaptive]); }
else { rate_decay = powf(w[adaptive], s.minus_power_t); }
}
if VW_STD17_CONSTEXPR (normalized != 0)
Expand Down

0 comments on commit 6883221

Please sign in to comment.