Skip to content

Commit

Permalink
eccurve refactor draft
Browse files Browse the repository at this point in the history
  • Loading branch information
iaik-jheher committed Oct 28, 2024
1 parent 78a3742 commit 385f2fa
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 193 deletions.
351 changes: 210 additions & 141 deletions indispensable/src/commonMain/kotlin/at/asitplus/signum/ecmath/ECMath.kt
Original file line number Diff line number Diff line change
@@ -1,124 +1,204 @@
package at.asitplus.signum.ecmath

import at.asitplus.signum.indispensable.ECPoint
import at.asitplus.signum.indispensable.*
import com.ionspin.kotlin.bignum.integer.BigInteger
import com.ionspin.kotlin.bignum.modular.ModularBigInteger
import kotlin.math.max

interface ECMathImpl {
/** checked in ECMathTest.kt */
fun checkRequirements(curve: NewECCurve)

/** addition (not necessarily constant time) */
fun plus(v: ECPoint, w: ECPoint): ECPoint = ct_plus(v,w)
/** constant-time point addition */
fun ct_plus(v: ECPoint, w: ECPoint): ECPoint

/** adding a point to itself */
fun double(p: ECPoint): ECPoint = ct_double(p)
/** constant-time doubling */
fun ct_double(p: ECPoint): ECPoint

/** scalar multiplication (not necessarily constant time) */
fun mul(p: BigInteger, Q: ECPoint) = ct_mul(p,Q)
/** constant-time scalar multiplication */
fun ct_mul(p: BigInteger, Q: ECPoint): ECPoint
}

val NewECCurve.math: ECMathImpl inline get() = when (this) {
ECCurve.SECP_256_R_1, ECCurve.SECP_384_R_1, ECCurve.SECP_521_R_1 -> WeierstrassArithmeticForAEqualsMinus3
}

object WeierstrassArithmeticForAEqualsMinus3: ECMathImpl {
override fun checkRequirements(curve: NewECCurve) {
check(curve is WeierstrassCurve)
check(curve.a == BigInteger(-3).toModularBigInteger(curve.modulus))
}

/* Algorithm 4 from https://eprint.iacr.org/2015/1060.pdf */
override fun ct_plus(v: ECPoint, w: ECPoint): ECPoint {
val b = v.curve.b
val X1 = v.homX
val Y1 = v.homY
val Z1 = v.homZ
val X2 = w.homX
val Y2 = w.homY
val Z2 = w.homZ
var t0: ModularBigInteger
var t1: ModularBigInteger
var t2: ModularBigInteger
var t3: ModularBigInteger
var t4: ModularBigInteger
var X3: ModularBigInteger
var Y3: ModularBigInteger
var Z3: ModularBigInteger
/* 1. */ t0 = X1 * X2
/* 2. */ t1 = Y1 * Y2
/* 3. */ t2 = Z1 * Z2
/* 4. */ t3 = X1 + Y1
/* 5. */ t4 = X2 + Y2
/* 6. */ t3 = t3 * t4
/* 7. */ t4 = t0 + t1
/* 8. */ t3 = t3 - t4
/* 9. */ t4 = Y1 + Z1
/* 10. */ X3 = Y2 + Z2
/* 11. */ t4 = t4 * X3
/* 12. */ X3 = t1 + t2
/* 13. */ t4 = t4 - X3
/* 14. */ X3 = X1 + Z1
/* 15. */ Y3 = X2 + Z2
/* 16. */ X3 = X3 * Y3
/* 17. */ Y3 = t0 + t2
/* 18. */ Y3 = X3 - Y3
/* 19. */ Z3 = b * t2
/* 20. */ X3 = Y3 - Z3
/* 21. */ Z3 = X3 + X3
/* 22. */ X3 = X3 + Z3
/* 23. */ Z3 = t1 - X3
/* 24. */ X3 = t1 + X3
/* 25. */ Y3 = b * Y3
/* 26. */ t1 = t2 + t2
/* 27. */ t2 = t1 + t2
/* 28. */ Y3 = Y3 - t2
/* 29. */ Y3 = Y3 - t0
/* 30. */ t1 = Y3 + Y3
/* 31. */ Y3 = t1 + Y3
/* 32. */ t1 = t0 + t0
/* 33. */ t0 = t1 + t0
/* 34. */ t0 = t0 - t2
/* 35. */ t1 = t4 * Y3
/* 36. */ t2 = t0 * Y3
/* 37. */ Y3 = X3 * Z3
/* 38. */ Y3 = Y3 + t2
/* 39. */ X3 = t3 * X3
/* 40. */ X3 = X3 - t1
/* 41. */ Z3 = t4 * Z3
/* 42. */ t1 = t3 * t0
/* 43. */ Z3 = Z3 + t1
return ECPoint.General.unsafeFromXYZ(v.curve, X3, Y3, Z3)
}

/* Algorithm 6 from https://eprint.iacr.org/2015/1060.pdf */
override fun ct_double(p: ECPoint): ECPoint {
val b = p.curve.b
val X = p.homX
val Y = p.homY
val Z = p.homZ
var t0: ModularBigInteger
var t1: ModularBigInteger
var t2: ModularBigInteger
var t3: ModularBigInteger
var X3: ModularBigInteger
var Y3: ModularBigInteger
var Z3: ModularBigInteger
/* 1. */ t0 = X * X
/* 2. */ t1 = Y * Y
/* 3. */ t2 = Z * Z
/* 4. */ t3 = X * Y
/* 5. */ t3 = t3 + t3
/* 6. */ Z3 = X * Z
/* 7. */ Z3 = Z3 + Z3
/* 8. */ Y3 = b * t2
/* 9. */ Y3 = Y3 - Z3
/* 10. */ X3 = Y3 + Y3
/* 11. */ Y3 = X3 + Y3
/* 12. */ X3 = t1 - Y3
/* 13. */ Y3 = t1 + Y3
/* 14. */ Y3 = X3 * Y3
/* 15. */ X3 = X3 * t3
/* 16. */ t3 = t2 + t2
/* 17. */ t2 = t2 + t3
/* 18. */ Z3 = b * Z3
/* 19. */ Z3 = Z3 - t2
/* 20. */ Z3 = Z3 - t0
/* 21. */ t3 = Z3 + Z3
/* 22. */ Z3 = Z3 + t3
/* 23. */ t3 = t0 + t0
/* 24. */ t0 = t3 + t0
/* 25. */ t0 = t0 - t2
/* 26. */ t0 = t0 * Z3
/* 27. */ Y3 = Y3 + t0
/* 28. */ t0 = Y * Z
/* 29. */ t0 = t0 + t0
/* 30. */ Z3 = t0 * Z3
/* 31. */ X3 = X3 - Z3
/* 32. */ Z3 = t0 * t1
/* 33. */ Z3 = Z3 + Z3
/* 34. */ Z3 = Z3 + Z3
return ECPoint.General.unsafeFromXYZ(p.curve, X3, Y3, Z3)
}

// TODO: i'm sure this could be smarter (keyword: "comb")
override fun mul(p: BigInteger, Q: ECPoint): ECPoint {
var o = Q
var sum = if (p.bitAt(0)) Q else Q.curve.IDENTITY
/* double-and-add */
for (i in 1L..<p.bitLength()) {
/* we double o on each iteration (it is (2^i)*point) */
o = o.double()
/* and decide whether to add it based on the bit */
if (p.bitAt(i)) sum += o
}
return sum
}

override fun ct_mul(p: BigInteger, Q: ECPoint): ECPoint {
var R0 = Q
var R1 = Q.double()
var i = p.bitLength().toLong()-1
while (--i >= 0) {
if (p.bitAt(i)) {
R0 = (R0+R1)
R1 = R1.double()
} else {
R1 = (R0+R1)
R0 = R0.double()
}
}
return R0
}
}

/** adds `other` to `this` and returns the result */
/* Algorithm 4 from https://eprint.iacr.org/2015/1060.pdf */
operator fun ECPoint.plus(other: ECPoint): ECPoint {
inline operator fun ECPoint.plus(other: ECPoint): ECPoint {
require(this.curve == other.curve)
val b = this.curve.b
val X1 = this.homX
val Y1 = this.homY
val Z1 = this.homZ
val X2 = other.homX
val Y2 = other.homY
val Z2 = other.homZ
var t0: ModularBigInteger
var t1: ModularBigInteger
var t2: ModularBigInteger
var t3: ModularBigInteger
var t4: ModularBigInteger
var X3: ModularBigInteger
var Y3: ModularBigInteger
var Z3: ModularBigInteger
/* 1. */ t0 = X1 * X2
/* 2. */ t1 = Y1 * Y2
/* 3. */ t2 = Z1 * Z2
/* 4. */ t3 = X1 + Y1
/* 5. */ t4 = X2 + Y2
/* 6. */ t3 = t3 * t4
/* 7. */ t4 = t0 + t1
/* 8. */ t3 = t3 - t4
/* 9. */ t4 = Y1 + Z1
/* 10. */ X3 = Y2 + Z2
/* 11. */ t4 = t4 * X3
/* 12. */ X3 = t1 + t2
/* 13. */ t4 = t4 - X3
/* 14. */ X3 = X1 + Z1
/* 15. */ Y3 = X2 + Z2
/* 16. */ X3 = X3 * Y3
/* 17. */ Y3 = t0 + t2
/* 18. */ Y3 = X3 - Y3
/* 19. */ Z3 = b * t2
/* 20. */ X3 = Y3 - Z3
/* 21. */ Z3 = X3 + X3
/* 22. */ X3 = X3 + Z3
/* 23. */ Z3 = t1 - X3
/* 24. */ X3 = t1 + X3
/* 25. */ Y3 = b * Y3
/* 26. */ t1 = t2 + t2
/* 27. */ t2 = t1 + t2
/* 28. */ Y3 = Y3 - t2
/* 29. */ Y3 = Y3 - t0
/* 30. */ t1 = Y3 + Y3
/* 31. */ Y3 = t1 + Y3
/* 32. */ t1 = t0 + t0
/* 33. */ t0 = t1 + t0
/* 34. */ t0 = t0 - t2
/* 35. */ t1 = t4 * Y3
/* 36. */ t2 = t0 * Y3
/* 37. */ Y3 = X3 * Z3
/* 38. */ Y3 = Y3 + t2
/* 39. */ X3 = t3 * X3
/* 40. */ X3 = X3 - t1
/* 41. */ Z3 = t4 * Z3
/* 42. */ t1 = t3 * t0
/* 43. */ Z3 = Z3 + t1
return ECPoint.General.unsafeFromXYZ(curve, X3, Y3, Z3)
return this.curve.math.plus(this, other)
}

inline infix fun ECPoint.ct_plus(other: ECPoint): ECPoint {
require(this.curve == other.curve)
return this.curve.math.ct_plus(this, other)
}


/** adds `this` to `this` and returns the result */
/* Algorithm 6 from https://eprint.iacr.org/2015/1060.pdf */
fun ECPoint.double(): ECPoint {
val b = this.curve.b
val X = this.homX
val Y = this.homY
val Z = this.homZ
var t0: ModularBigInteger
var t1: ModularBigInteger
var t2: ModularBigInteger
var t3: ModularBigInteger
var X3: ModularBigInteger
var Y3: ModularBigInteger
var Z3: ModularBigInteger
/* 1. */ t0 = X * X
/* 2. */ t1 = Y * Y
/* 3. */ t2 = Z * Z
/* 4. */ t3 = X * Y
/* 5. */ t3 = t3 + t3
/* 6. */ Z3 = X * Z
/* 7. */ Z3 = Z3 + Z3
/* 8. */ Y3 = b * t2
/* 9. */ Y3 = Y3 - Z3
/* 10. */ X3 = Y3 + Y3
/* 11. */ Y3 = X3 + Y3
/* 12. */ X3 = t1 - Y3
/* 13. */ Y3 = t1 + Y3
/* 14. */ Y3 = X3 * Y3
/* 15. */ X3 = X3 * t3
/* 16. */ t3 = t2 + t2
/* 17. */ t2 = t2 + t3
/* 18. */ Z3 = b * Z3
/* 19. */ Z3 = Z3 - t2
/* 20. */ Z3 = Z3 - t0
/* 21. */ t3 = Z3 + Z3
/* 22. */ Z3 = Z3 + t3
/* 23. */ t3 = t0 + t0
/* 24. */ t0 = t3 + t0
/* 25. */ t0 = t0 - t2
/* 26. */ t0 = t0 * Z3
/* 27. */ Y3 = Y3 + t0
/* 28. */ t0 = Y * Z
/* 29. */ t0 = t0 + t0
/* 30. */ Z3 = t0 * Z3
/* 31. */ X3 = X3 - Z3
/* 32. */ Z3 = t0 * t1
/* 33. */ Z3 = Z3 + Z3
/* 34. */ Z3 = Z3 + Z3
return ECPoint.General.unsafeFromXYZ(curve, X3, Y3, Z3)
inline fun ECPoint.double(): ECPoint {
return this.curve.math.double(this)
}

/** adds `this` to `this` in constant time and returns the result */
inline fun ECPoint.ct_double(): ECPoint {
return this.curve.math.ct_double(this)
}

@Suppress("NOTHING_TO_INLINE")
Expand All @@ -135,20 +215,13 @@ inline operator fun ECPoint.Normalized.unaryMinus() =
@Suppress("NOTHING_TO_INLINE")
inline operator fun ECPoint.minus(other: ECPoint) = this + (-other)

// TODO: i'm sure this could be smarter (keyword: "comb")
// i'm also sure this isn't resistant to timing side channels if that is something you care about
operator fun BigInteger.times(point: ECPoint): ECPoint {
var o = point
var sum = if (this.bitAt(0)) point else point.curve.IDENTITY
/* double-and-add */
for (i in 1L..<this.bitLength()) {
/* we double o on each iteration (it is (2^i)*point) */
o = o.double()
/* and decide whether to add it based on the bit */
if (this.bitAt(i)) sum += o
}
return sum
}
inline infix fun ECPoint.ct_minus(other: ECPoint) = this ct_plus (-other)

inline operator fun BigInteger.times(point: ECPoint) =
point.curve.math.mul(this, point)

inline infix fun BigInteger.ct_mul(point: ECPoint) =
point.curve.math.ct_mul(this, point)

@Suppress("NOTHING_TO_INLINE")
inline operator fun Int.times(point: ECPoint) =
Expand All @@ -172,6 +245,19 @@ inline operator fun ModularBigInteger.times(point: ECPoint): ECPoint {
return this.residue.times(point)
}

inline infix fun Int.ct_mul(point: ECPoint) =
BigInteger.fromInt(this).ct_mul(point)
inline infix fun Long.ct_mul(point: ECPoint) =
BigInteger.fromLong(this).ct_mul(point)
inline infix fun UInt.ct_mul(point: ECPoint) =
BigInteger.fromUInt(this).ct_mul(point)
inline infix fun ULong.ct_mul(point: ECPoint) =
BigInteger.fromULong(this).ct_mul(point)
inline infix fun ModularBigInteger.ct_mul(point: ECPoint): ECPoint {
require(this.modulus == point.curve.order)
return this.residue.ct_mul(point)
}

/* these are intentionally not operator functions! */
@Suppress("NOTHING_TO_INLINE")
inline fun ECPoint.times(v: BigInteger) = v * this
Expand Down Expand Up @@ -208,20 +294,3 @@ fun straussShamir(u: BigInteger, G: ECPoint, v: BigInteger, Q: ECPoint): ECPoint
}
return R
}

/** computes pQ in constant time */
fun montgomeryMul(k: BigInteger, P: ECPoint): ECPoint {
var R0 = P
var R1 = P.double()
var i = k.bitLength().toLong()-1
while (--i >= 0) {
if (k.bitAt(i)) {
R0 = (R0+R1)
R1 = R1.double()
} else {
R1 = (R0+R1)
R0 = R0.double()
}
}
return R0
}
Loading

0 comments on commit 385f2fa

Please sign in to comment.