Skip to content

Commit

Permalink
Explore prefetching to try to fix #446
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jul 31, 2024
1 parent 57911e8 commit cc7eba7
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 51 deletions.
2 changes: 1 addition & 1 deletion benchmarks/bench_blueprint.nim
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ proc warmup*() =
let stop = cpuTime()
echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n"

warmup()
# warmup()

when defined(gcc):
echo "\nCompiled with GCC"
Expand Down
70 changes: 35 additions & 35 deletions benchmarks/bench_ec_g1.nim
Original file line number Diff line number Diff line change
Expand Up @@ -26,54 +26,54 @@ import
# ############################################################


const Iters = 10_000
const Iters = 10_000_000
const MulIters = 100
const AvailableCurves = [
# P224,
BN254_Nogami,
BN254_Snarks,
# BN254_Nogami,
# BN254_Snarks,
# Edwards25519,
# P256,
Secp256k1,
Pallas,
Vesta,
BLS12_377,
BLS12_381,
# Pallas,
# Vesta,
# BLS12_377,
# BLS12_381,
]

proc main() =
separator()
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
addBench(EC_ShortW_Prj[Fp[curve], G1], Iters)
addBench(EC_ShortW_Jac[Fp[curve], G1], Iters)
addBench(EC_ShortW_JacExt[Fp[curve], G1], Iters)
mixedAddBench(EC_ShortW_Prj[Fp[curve], G1], Iters)
mixedAddBench(EC_ShortW_Jac[Fp[curve], G1], Iters)
mixedAddBench(EC_ShortW_JacExt[Fp[curve], G1], Iters)
doublingBench(EC_ShortW_Prj[Fp[curve], G1], Iters)
doublingBench(EC_ShortW_Jac[Fp[curve], G1], Iters)
doublingBench(EC_ShortW_JacExt[Fp[curve], G1], Iters)
separator()
affFromProjBench(EC_ShortW_Prj[Fp[curve], G1], MulIters)
affFromJacBench(EC_ShortW_Jac[Fp[curve], G1], MulIters)
separator()
for numPoints in [10, 100, 1000, 10000]:
let batchIters = max(1, Iters div numPoints)
affFromProjBatchBench(EC_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = false, batchIters)
separator()
for numPoints in [10, 100, 1000, 10000]:
let batchIters = max(1, Iters div numPoints)
affFromProjBatchBench(EC_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = true, batchIters)
separator()
for numPoints in [10, 100, 1000, 10000]:
let batchIters = max(1, Iters div numPoints)
affFromJacBatchBench(EC_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = false, batchIters)
separator()
for numPoints in [10, 100, 1000, 10000]:
let batchIters = max(1, Iters div numPoints)
affFromJacBatchBench(EC_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = true, batchIters)
separator()
# addBench(EC_ShortW_Jac[Fp[curve], G1], Iters)
# addBench(EC_ShortW_JacExt[Fp[curve], G1], Iters)
# mixedAddBench(EC_ShortW_Prj[Fp[curve], G1], Iters)
# mixedAddBench(EC_ShortW_Jac[Fp[curve], G1], Iters)
# mixedAddBench(EC_ShortW_JacExt[Fp[curve], G1], Iters)
# doublingBench(EC_ShortW_Prj[Fp[curve], G1], Iters)
# doublingBench(EC_ShortW_Jac[Fp[curve], G1], Iters)
# doublingBench(EC_ShortW_JacExt[Fp[curve], G1], Iters)
# separator()
# affFromProjBench(EC_ShortW_Prj[Fp[curve], G1], MulIters)
# affFromJacBench(EC_ShortW_Jac[Fp[curve], G1], MulIters)
# separator()
# for numPoints in [10, 100, 1000, 10000]:
# let batchIters = max(1, Iters div numPoints)
# affFromProjBatchBench(EC_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = false, batchIters)
# separator()
# for numPoints in [10, 100, 1000, 10000]:
# let batchIters = max(1, Iters div numPoints)
# affFromProjBatchBench(EC_ShortW_Prj[Fp[curve], G1], numPoints, useBatching = true, batchIters)
# separator()
# for numPoints in [10, 100, 1000, 10000]:
# let batchIters = max(1, Iters div numPoints)
# affFromJacBatchBench(EC_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = false, batchIters)
# separator()
# for numPoints in [10, 100, 1000, 10000]:
# let batchIters = max(1, Iters div numPoints)
# affFromJacBatchBench(EC_ShortW_Jac[Fp[curve], G1], numPoints, useBatching = true, batchIters)
# separator()
separator()

main()
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/bench_elliptic_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ proc addBench*(EC: typedesc, iters: int) {.noinline.} =
block:
bench("EC Add " & $EC.G, EC, iters):
r.sum(P, Q)
block:
bench("EC Add vartime " & $EC.G, EC, iters):
r.sum_vartime(P, Q)
# block:
# bench("EC Add vartime " & $EC.G, EC, iters):
# r.sum_vartime(P, Q)

proc mixedAddBench*(EC: typedesc, iters: int) {.noinline.} =
var r {.noInit.}: EC
Expand Down
17 changes: 17 additions & 0 deletions constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ proc finalSubNoOverflowImpl*(
if not a_in_scratch:
ctx.mov scratch[0], a[0]
ctx.sub scratch[0], M[0]
# Combat cache-misses
# https://github.com/mratsim/constantine/issues/446#issuecomment-2254258024
ctx.prefetchw r
for i in 1 ..< N:
if not a_in_scratch:
ctx.mov scratch[i], a[i]
Expand Down Expand Up @@ -75,6 +78,9 @@ proc finalSubMayOverflowImpl*(
if not a_in_scratch:
ctx.mov scratch[0], a[0]
ctx.sub scratch[0], M[0]
# Combat cache-misses
# https://github.com/mratsim/constantine/issues/446#issuecomment-2254258024
ctx.prefetchw r
for i in 1 ..< N:
if not a_in_scratch:
ctx.mov scratch[i], a[i]
Expand Down Expand Up @@ -157,6 +163,9 @@ macro addmod_gen[N: static int](r_PIR: var Limbs[N], a_PIR, b_PIR, M_MEM: Limbs[
# Addition
ctx.add u[0], b[0]
ctx.mov v[0], u[0]
# Combat cache-misses
# https://github.com/mratsim/constantine/issues/446#issuecomment-2254258024
ctx.prefetcht0 M
for i in 1 ..< N:
ctx.adc u[i], b[i]
# Interleaved copy in a second buffer as well
Expand Down Expand Up @@ -215,6 +224,10 @@ macro submod_gen[N: static int](r_PIR: var Limbs[N], a_PIR, b_PIR, M_MEM: Limbs[
let underflowed = b.reuseRegister()
ctx.sbb underflowed, underflowed

# Combat cache-misses
# https://github.com/mratsim/constantine/issues/446#issuecomment-2254258024
ctx.prefetchw r

# Now mask the adder, with 0 or the modulus limbs
for i in 0 ..< N:
ctx.`and` v[i], underflowed
Expand Down Expand Up @@ -265,6 +278,10 @@ macro negmod_gen[N: static int](r_PIR: var Limbs[N], a_MEM, M_MEM: Limbs[N]): un
ctx.mov u[i], M[i]
ctx.sbb u[i], a[i]

# Combat cache-misses
# https://github.com/mratsim/constantine/issues/446#issuecomment-2254258024
ctx.prefetchw r

# Deal with a == 0
ctx.mov isZero, a[0]
for i in 1 ..< N:
Expand Down
24 changes: 12 additions & 12 deletions constantine/math/arithmetic/limbs_crandall.nim
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ func mulCranPartialReduce[N: static int](
m: static int, c: static SecretWord) {.inline.} =
when UseASM_X86_64 and a.len in {3..6}:
# ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
# if ({.noSideEffect.}: hasAdx()):
r.mulCranPartialReduce_asm_adx(a, b, m, c)
else:
r.mulCranPartialReduce_asm(a, b, m, c)
# else:
# r.mulCranPartialReduce_asm(a, b, m, c)
else:
var r2 {.noInit.}: Limbs[2*N]
r2.prod(a, b)
Expand All @@ -208,10 +208,10 @@ func mulCran*[N: static int](
r.mulCranPartialReduce(a, b, m, c)
elif UseASM_X86_64 and a.len in {3..6}:
# ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
# if ({.noSideEffect.}: hasAdx()):
r.mulCran_asm_adx(a, b, p, m, c)
else:
r.mulCran_asm(a, b, p, m, c)
# else:
# r.mulCran_asm(a, b, p, m, c)
else:
var r2 {.noInit.}: Limbs[2*N]
r2.prod(a, b)
Expand All @@ -224,10 +224,10 @@ func squareCranPartialReduce[N: static int](
m: static int, c: static SecretWord) {.inline.} =
when UseASM_X86_64 and a.len in {3..6}:
# ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
# if ({.noSideEffect.}: hasAdx()):
r.squareCranPartialReduce_asm_adx(a, m, c)
else:
r.squareCranPartialReduce_asm(a, m, c)
# else:
# r.squareCranPartialReduce_asm(a, m, c)
else:
var r2 {.noInit.}: Limbs[2*N]
r2.square(a)
Expand All @@ -243,10 +243,10 @@ func squareCran*[N: static int](
r.squareCranPartialReduce(a, m, c)
elif UseASM_X86_64 and a.len in {3..6}:
# ADX implies BMI2
if ({.noSideEffect.}: hasAdx()):
# if ({.noSideEffect.}: hasAdx()):
r.squareCran_asm_adx(a, p, m, c)
else:
r.squareCran_asm(a, p, m, c)
# else:
# r.squareCran_asm(a, p, m, c)
else:
var r2 {.noInit.}: Limbs[2*N]
r2.square(a)
Expand Down
12 changes: 12 additions & 0 deletions constantine/platforms/x86/macro_assembler_x86_att.nim
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,18 @@ func setc*(a: var Assembler_x86, dst: Register) =
a.code &= "setc " & Reg8Low[dst] & '\n'
# No flags affected

func prefetcht0*(a: var Assembler_x86, mem: Operand or OperandArray) =
  ## Append a `prefetcht0` instruction to the generated assembly.
  ## Used as a read-oriented prefetch hint (all cache levels).
  ##
  ## Only the first element of `mem` (`mem[0]`) is used to build
  ## the address; its textual form comes from `getStrOffset`.
  let firstElem = mem[0]
  let target = a.getStrOffset(firstElem)
  a.code.add "prefetcht0 " & target & '\n'
  # No flags affected

func prefetchw*(a: var Assembler_x86, mem: Operand or OperandArray) =
  ## Append a `prefetchw` instruction to the generated assembly.
  ## Used as a write-oriented prefetch hint (all cache levels).
  ##
  ## Only the first element of `mem` (`mem[0]`) is used to build
  ## the address; its textual form comes from `getStrOffset`.
  let firstElem = mem[0]
  let target = a.getStrOffset(firstElem)
  a.code.add "prefetchw " & target & '\n'
  # No flags affected

func add*(a: var Assembler_x86, dst, src: Operand) =
## Does: dst <- dst + src
doAssert dst.isOutput()
Expand Down
25 changes: 25 additions & 0 deletions constantine/platforms/x86/macro_assembler_x86_intel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,31 @@ func setc*(a: var Assembler_x86, dst: Register) =
a.code &= "setc " & Reg8Low[dst] & '\n'
# No flags affected

func getPrefetchLoc(mem: Operand or OperandArray): string =
  ## Build the Intel-syntax memory operand string used by the
  ## prefetch instructions, based on the first element of `mem`.
  ##
  ## The result is always a `BYTE ptr` reference; its exact shape
  ## depends on the operand's `desc.rm` constraint:
  ## - memory operands:            `BYTE ptr %<asmId>`
  ## - address held in a register: `BYTE ptr [%<asmId>]`
  ## - clobbered register:         `BYTE ptr [<asmId>]`
  ## Any other constraint is reported through `error`.
  let head = mem[0]
  let rm = head.desc.rm
  let id = head.desc.asmId

  # The operand designates a register that holds the address.
  let addrInRegister =
    rm == PointerInReg or
    rm in SpecificRegisters or
    (rm == ElemsInReg and head.kind == kFromArray)

  if rm in {Mem, MemOffsettable}:
    result = "BYTE ptr %" & id
  elif addrInRegister:
    result = "BYTE ptr [%" & id & "]"
  elif rm == ClobberedReg:
    # Clobbered registers are referenced without the `%` prefix.
    result = "BYTE ptr [" & id & "]"
  else:
    error("Unsupported memory operand type for prefetch: " & head.repr)

func prefetcht0*(a: var Assembler_x86, mem: Operand or OperandArray) =
  ## Append a `prefetcht0` instruction to the generated assembly.
  ## Used as a read-oriented prefetch hint (all cache levels).
  ##
  ## The operand string is produced by `getPrefetchLoc(mem)`.
  let target = getPrefetchLoc(mem)
  a.code.add "prefetcht0 " & target & '\n'
  # No flags affected

func prefetchw*(a: var Assembler_x86, mem: Operand or OperandArray) =
  ## Append a `prefetchw` instruction to the generated assembly.
  ## Used as a write-oriented prefetch hint (all cache levels).
  ##
  ## The operand string is produced by `getPrefetchLoc(mem)`.
  let target = getPrefetchLoc(mem)
  a.code.add "prefetchw " & target & '\n'
  # No flags affected

func add*(a: var Assembler_x86, dst, src: Operand) =
## Does: dst <- dst + src
doAssert dst.isOutput()
Expand Down

0 comments on commit cc7eba7

Please sign in to comment.