Skip to content

Commit

Permalink
Karatsuba multiplication now works
Browse files Browse the repository at this point in the history
  • Loading branch information
dlesnoff committed Jan 31, 2022
1 parent f3bcc9b commit bc118cb
Showing 1 changed file with 24 additions and 37 deletions.
61 changes: 24 additions & 37 deletions src/bigints.nim
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func initBigInt*(val: BigInt): BigInt =
const
zero = initBigInt(0)
one = initBigInt(1)
karatsubaTreshold = 5
karatsubaTreshold = 10

func isZero(a: BigInt): bool {.inline.} =
for i in countdown(a.limbs.high, 0):
Expand Down Expand Up @@ -418,35 +418,25 @@ func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
inc pos
normalize(a)

func scalarMultiplication(a: var BigInt, b: uint32, c: BigInt) {.inline.} =
# Based on unsignedMultiplication
func scalarMultiplication(a: var BigInt, b: BigInt, c: uint32) {.inline.} =
# always called with bl >= cl
let
cl = c.limbs.len
a.limbs.setLen(1 + cl)
bl = b.limbs.len
a.limbs.setLen(bl + 1)
var tmp = 0'u64

tmp += uint64(b) * uint64(c.limbs[0])
a.limbs[1] = uint32(tmp and uint32.high)
tmp = tmp shr 32 # carry

a.limbs[1] = uint32(tmp)

for j in 1 ..< cl:
tmp = 0'u64
tmp += uint64(a.limbs[j]) + uint64(b) * uint64(c.limbs[j])
a.limbs[j] = uint32(tmp and uint32.high)
for i in 0 ..< bl:
tmp += uint64(b.limbs[i]) * uint64(c)
a.limbs[i] = uint32(tmp and uint32.high)
tmp = tmp shr 32
var pos = j + 1
while tmp > 0'u64:
tmp += uint64(a.limbs[pos])
a.limbs[pos] = uint32(tmp and uint32.high)
tmp = tmp shr 32
inc pos

a.limbs[bl] = uint32(tmp)
normalize(a)

# forward declaration for use in `multiplication`
func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.}
func karatsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.}
func `shl`*(x: BigInt, y: Natural): BigInt
func `shr`*(x: BigInt, y: Natural): BigInt

func multiplication(a: var BigInt, b, c: BigInt) =
# a = b * c
Expand All @@ -459,28 +449,27 @@ func multiplication(a: var BigInt, b, c: BigInt) =

if cl > bl:
if bl <= karatsubaTreshold:
unsignedKaratsubaMultiplication(a, c, b)
karatsubaMultiplication(a, c, b)
else:
unsignedMultiplication(a, c, b)
else:
if cl <= karatsubaTreshold:
unsignedKaratsubaMultiplication(a, b, c)
karatsubaMultiplication(a, b, c)
else:
unsignedMultiplication(a, b, c)
a.isNegative = b.isNegative xor c.isNegative

func `shr`*(x: BigInt, y: Natural): BigInt
func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
func karatsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
let
bl = b.limbs.len
cl = c.limbs.len
let n = max(bl, cl)
if bl == 1:
# base case : multiply the only limb with each limb of second term
scalarMultiplication(a, b.limbs[0], c)
scalarMultiplication(a, c, b.limbs[0])
return
if cl == 1:
scalarMultiplication(a, c.limbs[0], b)
scalarMultiplication(a, b, c.limbs[0])
return
if bl < karatsubaTreshold:
if cl <= bl:
Expand All @@ -507,21 +496,19 @@ func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
# limit carry handling in opposition to the additive version
var
lowProduct, highProduct, A3, A4, A5, middleTerm: BigInt = zero
unsignedKaratsubaMultiplication(lowProduct, low_b, low_c)
unsignedKaratsubaMultiplication(highProduct, high_b, high_c)
karatsubaMultiplication(lowProduct, low_b, low_c)
karatsubaMultiplication(highProduct, high_b, high_c)
A3 = low_b - high_b # Additive variant of Karatsuba
A4 = high_c - low_c # would add them
A4 = low_c - high_c # would add them
if A4.limbs.len >= A3.limbs.len:
multiplication(A5, abs(A4), abs(A3))
else:
multiplication(A5, abs(A3), abs(A4))
middleTerm = lowProduct + highProduct + A5
a = lowProduct + (middleTerm shr k) + (highProduct shr (2*k))
# We could affect directly some of the bits of the result with slicing
# a.limbs[0 .. k - 1] = lowProduct.limbs
# But the following instructions would not be correct due to sign handling
# a.limbs[k .. 2*k-1] = middleTerm.limbs
# a.limbs[2*k .. 3*k-1] = highProduct.limbs
a.limbs[0 .. k - 1] = lowProduct.limbs
# a += (middleTerm shr k) + (highProduct shr (2*k))
a.limbs[k .. 2*k-1] = middleTerm.limbs
a.limbs[2*k .. 3*k-1] = highProduct.limbs


func `*`*(a, b: BigInt): BigInt =
Expand Down

0 comments on commit bc118cb

Please sign in to comment.