Disable FMA by default. Use -Ofma or jit.opt.start("+fma") to enable.
See the discussion in the corresponding ticket for the rationale.

(cherry picked from commit de2e1ca)

For the modulo operation, the arm64 VM uses the `fmsub` [1] instruction,
which is a fused multiply-add (FMA [2]) operation (more precisely, a
fused multiply-subtract). Because it rounds only once, it may produce
results that differ from an unfused multiply followed by a subtract.
This patch fixes the behaviour by using the unfused instructions by
default. A new JIT optimization flag (fma) is also introduced, so the
FMA optimizations can still be enabled explicitly.

Sergey Kaplun:
* added the description and the test for the problem

[1]: https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB
[2]: https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation

Part of tarantool/tarantool#10709

Reviewed-by: Sergey Bronnikov <[email protected]>
Signed-off-by: Sergey Kaplun <[email protected]>
(cherry picked from commit 58b013a)
Mike Pall authored and Buristan committed Jan 20, 2025
1 parent 73674ed commit 840a1f4
Showing 10 changed files with 151 additions and 6 deletions.
8 changes: 8 additions & 0 deletions doc/running.html
@@ -226,6 +226,12 @@ <h3 id="opt_O"><tt>-O[level]</tt><br>
overrides all earlier flags.
</p>
<p>
+Note that <tt>-Ofma</tt> is not enabled by default at any level,
+because it affects floating-point result accuracy. Only enable this,
+if you fully understand the trade-offs of FMA for performance (higher),
+determinism (lower) and numerical accuracy (higher).
+</p>
+<p>
Here are the available flags and at what optimization levels they
are enabled:
</p>
@@ -257,6 +263,8 @@ <h3 id="opt_O"><tt>-O[level]</tt><br>
<td class="flag_name">sink</td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_level">&bull;</td><td class="flag_desc">Allocation/Store Sinking</td></tr>
<tr class="even">
<td class="flag_name">fuse</td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_level">&bull;</td><td class="flag_desc">Fusion of operands into instructions</td></tr>
+<tr class="odd">
+<td class="flag_name">fma </td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_desc">Fused multiply-add</td></tr>
</table>
<p>
Here are the parameters and their default settings:
6 changes: 5 additions & 1 deletion src/lj_asm_arm.h
@@ -310,7 +310,11 @@ static void asm_fusexref(ASMState *as, ARMIns ai, Reg rd, IRRef ref,
}

#if !LJ_SOFTFP
-/* Fuse to multiply-add/sub instruction. */
+/*
+** Fuse to multiply-add/sub instruction.
+** VMLA rounds twice (UMA, not FMA) -- no need to check for JIT_F_OPT_FMA.
+** VFMA needs VFPv4, which is uncommon on the remaining ARM32 targets.
+*/
static int asm_fusemadd(ASMState *as, IRIns *ir, ARMIns ai, ARMIns air)
{
IRRef lref = ir->op1, rref = ir->op2;
3 changes: 2 additions & 1 deletion src/lj_asm_arm64.h
@@ -334,7 +334,8 @@ static int asm_fusemadd(ASMState *as, IRIns *ir, A64Ins ai, A64Ins air)
{
IRRef lref = ir->op1, rref = ir->op2;
IRIns *irm;
-if (lref != rref &&
+if ((as->flags & JIT_F_OPT_FMA) &&
+lref != rref &&
((mayfuse(as, lref) && (irm = IR(lref), irm->o == IR_MUL) &&
ra_noreg(irm->r)) ||
(mayfuse(as, rref) && (irm = IR(rref), irm->o == IR_MUL) &&
3 changes: 2 additions & 1 deletion src/lj_asm_ppc.h
@@ -232,7 +232,8 @@ static int asm_fusemadd(ASMState *as, IRIns *ir, PPCIns pi, PPCIns pir)
{
IRRef lref = ir->op1, rref = ir->op2;
IRIns *irm;
-if (lref != rref &&
+if ((as->flags & JIT_F_OPT_FMA) &&
+lref != rref &&
((mayfuse(as, lref) && (irm = IR(lref), irm->o == IR_MUL) &&
ra_noreg(irm->r)) ||
(mayfuse(as, rref) && (irm = IR(rref), irm->o == IR_MUL) &&
4 changes: 3 additions & 1 deletion src/lj_jit.h
@@ -86,10 +86,11 @@
#define JIT_F_OPT_ABC (JIT_F_OPT << 7)
#define JIT_F_OPT_SINK (JIT_F_OPT << 8)
#define JIT_F_OPT_FUSE (JIT_F_OPT << 9)
+#define JIT_F_OPT_FMA (JIT_F_OPT << 10)

/* Optimizations names for -O. Must match the order above. */
#define JIT_F_OPTSTRING \
-"\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse"
+"\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse\3fma"

/* Optimization levels set a fixed combination of flags. */
#define JIT_F_OPT_0 0
@@ -98,6 +99,7 @@
#define JIT_F_OPT_3 (JIT_F_OPT_2|\
JIT_F_OPT_FWD|JIT_F_OPT_DSE|JIT_F_OPT_ABC|JIT_F_OPT_SINK|JIT_F_OPT_FUSE)
#define JIT_F_OPT_DEFAULT JIT_F_OPT_3
+/* Note: FMA is not set by default. */

/* -- JIT engine parameters ----------------------------------------------- */
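
The option string above stores each flag name with a one-byte length prefix, and a name's position in the string selects its flag bit, which is why the comment requires the two orders to match. A minimal sketch of such a lookup, assuming a hypothetical base bit `OPT_BASE` and a hypothetical helper `opt_flag_for` (neither is LuaJIT's actual code):

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical base bit; the real JIT_F_OPT value is not shown here. */
#define OPT_BASE 0x00010000u

/* Length-prefixed flag names, in the same order as the flag bits. */
static const char opt_names[] =
  "\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse\3fma";

/* Return the flag bit for a name, or 0 if the name is unknown. */
static unsigned opt_flag_for(const char *name)
{
  const unsigned char *p = (const unsigned char *)opt_names;
  size_t want = strlen(name);
  unsigned bit = OPT_BASE;
  while (*p) {
    size_t len = *p++;  /* One-byte length prefix. */
    if (len == want && memcmp(p, name, len) == 0)
      return bit;
    p += len;           /* Skip to the next length prefix. */
    bit <<= 1;          /* The next name maps to the next bit. */
  }
  return 0;
}
```

Under these assumptions, `opt_flag_for("fma")` yields `OPT_BASE << 10`, matching the shift used for `JIT_F_OPT_FMA`, and appending a name to the string without adding the corresponding bit definition would break the mapping.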

13 changes: 12 additions & 1 deletion src/lj_vmmath.c
@@ -36,14 +36,25 @@ LJ_FUNCA double lj_wrap_fmod(double x, double y) { return fmod(x, y); }

/* -- Helper functions ---------------------------------------------------- */

+/* Required to prevent the C compiler from applying FMA optimizations.
+**
+** Yes, there's -ffp-contract and the FP_CONTRACT pragma ... in theory.
+** But the current state of C compilers is a mess in this regard.
+** Also, this function is not performance sensitive at all.
+*/
+LJ_NOINLINE static double lj_vm_floormul(double x, double y)
+{
+return lj_vm_floor(x / y) * y;
+}
+
double lj_vm_foldarith(double x, double y, int op)
{
switch (op) {
case IR_ADD - IR_ADD: return x+y; break;
case IR_SUB - IR_ADD: return x-y; break;
case IR_MUL - IR_ADD: return x*y; break;
case IR_DIV - IR_ADD: return x/y; break;
-case IR_MOD - IR_ADD: return x-lj_vm_floor(x/y)*y; break;
+case IR_MOD - IR_ADD: return x-lj_vm_floormul(x, y); break;
case IR_POW - IR_ADD: return pow(x, y); break;
case IR_NEG - IR_ADD: return -x; break;
case IR_ABS - IR_ADD: return fabs(x); break;
4 changes: 3 additions & 1 deletion src/vm_arm64.dasc
@@ -2581,7 +2581,9 @@ static void build_ins(BuildCtx *ctx, BCOp op, int defop)
|.macro ins_arithmod, res, reg1, reg2
| fdiv d2, reg1, reg2
| frintm d2, d2
-| fmsub res, d2, reg2, reg1
+| // Cannot use fmsub, because FMA is not enabled by default.
+| fmul d2, d2, reg2
+| fsub res, reg1, d2
|.endmacro
|
|.macro ins_arithdn, intins, fpins
48 changes: 48 additions & 0 deletions test/tarantool-tests/lj-918-fma-numerical-accuracy-jit.test.lua
@@ -0,0 +1,48 @@
local tap = require('tap')

-- Test file to demonstrate consistent behaviour for JIT and the
-- VM regarding FMA optimization (disabled by default).
-- XXX: The VM behaviour is checked in the
-- <lj-918-fma-numerical-accuracy.test.lua>.
-- See also: https://github.com/LuaJIT/LuaJIT/issues/918.
local test = tap.test('lj-918-fma-numerical-accuracy-jit'):skipcond({
['Test requires JIT enabled'] = not jit.status(),
})

test:plan(1)

local _2pow52 = 2 ^ 52

-- XXX: Before this commit the LuaJIT arm64 VM uses `fmsub` [1]
-- instruction for the modulo operation, which is the fused
-- multiply-add (FMA [2]) operation (more precisely,
-- multiply-sub). Hence, it may produce different results compared
-- to the unfused one. For the test, let's just use 2 numbers in
-- modulo for which the single rounding is different from the
-- double rounding. The numbers from the original issue are good
-- enough.
--
-- [1]:https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB
-- [2]:https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation
--
-- IEEE754 components to double:
-- sign * (2 ^ (exp - 1023)) * (mantissa / _2pow52 + normal).
local a = 1 * (2 ^ (1083 - 1023)) * (4080546448249347 / _2pow52 + 1)
assert(a == 2197541395358679800)

local b = -1 * (2 ^ (1052 - 1023)) * (3927497732209973 / _2pow52 + 1)
assert(b == -1005065126.3690554)

local results = {}

jit.opt.start('hotloop=1')
for i = 1, 4 do
results[i] = a % b
end

-- XXX: The test doesn't fail before this commit. But it is
-- required to be sure that there are no inconsistencies after the
-- commit.
test:samevalues(results, 'consistent behaviour between the JIT and the VM')

test:done(true)
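
The "IEEE754 components to double" formula used in the test can be cross-checked in C. This is a sketch for normal (non-subnormal, finite) numbers only; `from_components` and `from_bits` are hypothetical helper names, and `from_bits` assumes that `double` is an IEEE 754 binary64 value with the same byte order as `uint64_t`.

```c
#include <math.h>
#include <stdint.h>
#include <string.h>

/*
** Build a double from its IEEE 754 components using the formula from the
** test: sign * 2^(exp - 1023) * (mantissa / 2^52 + 1).
** Valid for normal numbers only; sign must be 1 or -1.
*/
static double from_components(int sign, int biased_exp, uint64_t mantissa)
{
  double significand = (double)mantissa / 4503599627370496.0 + 1.0; /* 2^52 */
  return sign * ldexp(significand, biased_exp - 1023);
}

/*
** Build the same double by assembling the raw bit pattern:
** 1 sign bit, 11 exponent bits, 52 mantissa bits.
*/
static double from_bits(unsigned sign_bit, unsigned biased_exp,
                        uint64_t mantissa)
{
  uint64_t bits = ((uint64_t)sign_bit << 63) |
                  ((uint64_t)biased_exp << 52) | mantissa;
  double d;
  memcpy(&d, &bits, sizeof(d)); /* Reinterpret the bits without aliasing UB. */
  return d;
}
```

For the operands in the test, `from_components(1, 1083, 4080546448249347)` should equal 2197541395358679800 and `from_components(-1, 1052, 3927497732209973)` should equal -1005065126.3690554, mirroring the two `assert` calls in the Lua code. `from_bits` should produce the same values from the raw fields.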
43 changes: 43 additions & 0 deletions test/tarantool-tests/lj-918-fma-numerical-accuracy.test.lua
@@ -0,0 +1,43 @@
local tap = require('tap')

-- Test file to demonstrate possible numerical inaccuracy if FMA
-- optimization takes place.
-- XXX: The JIT consistency is checked in the
-- <lj-918-fma-numerical-accuracy-jit.test.lua>.
-- See also: https://github.com/LuaJIT/LuaJIT/issues/918.
local test = tap.test('lj-918-fma-numerical-accuracy')

test:plan(2)

local _2pow52 = 2 ^ 52

-- XXX: Before this commit the LuaJIT arm64 VM uses `fmsub` [1]
-- instruction for the modulo operation, which is the fused
-- multiply-add (FMA [2]) operation (more precisely,
-- multiply-sub). Hence, it may produce different results compared
-- to the unfused one. For the test, let's just use 2 numbers in
-- modulo for which the single rounding is different from the
-- double rounding. The numbers from the original issue are good
-- enough.
--
-- [1]:https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB
-- [2]:https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation
--
-- IEEE754 components to double:
-- sign * (2 ^ (exp - 1023)) * (mantissa / _2pow52 + normal).
local a = 1 * (2 ^ (1083 - 1023)) * (4080546448249347 / _2pow52 + 1)
assert(a == 2197541395358679800)

local b = -1 * (2 ^ (1052 - 1023)) * (3927497732209973 / _2pow52 + 1)
assert(b == -1005065126.3690554)

-- These tests fail on ARM64 before this patch or with FMA
-- optimization enabled.
-- The first test may not fail if the compiler doesn't generate
-- an ARM64 FMA operation in `lj_vm_foldarith()`.
test:is(2197541395358679800 % -1005065126.3690554, -606337536,
'FMA in the lj_vm_foldarith() during parsing')

test:is(a % b, -606337536, 'FMA in the VM')

test:done(true)
25 changes: 25 additions & 0 deletions test/tarantool-tests/lj-918-fma-optimization.test.lua
@@ -0,0 +1,25 @@
local tap = require('tap')
local test = tap.test('lj-918-fma-optimization'):skipcond({
['Test requires JIT enabled'] = not jit.status(),
})

test:plan(3)

local function jit_opt_is_on(flag)
for _, opt in ipairs({jit.status()}) do
if opt == flag then
return true
end
end
return false
end

test:ok(not jit_opt_is_on('fma'), 'FMA is disabled by default')

local ok, _ = pcall(jit.opt.start, '+fma')

test:ok(ok, 'fma flag is recognized')

test:ok(jit_opt_is_on('fma'), 'FMA is enabled after jit.opt.start()')

test:done(true)
