Skip to content

Commit

Permalink
Merge pull request #783 from Consensys/perf/emulated-scalarMul
Browse files Browse the repository at this point in the history
Perf: save 1 `Select` at each iteration in the emulated `scalarMul`
  • Loading branch information
yelhousni authored Jul 25, 2023
2 parents 55b9478 + e1cb5a7 commit ceed757
Showing 1 changed file with 56 additions and 19 deletions.
75 changes: 56 additions & 19 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,54 @@ func (c *Curve[B, S]) doubleAndAdd(p, q *AffinePoint[B]) *AffinePoint[B] {

}

// doubleAndAddSelect is the same as doubleAndAdd but computes either:
//
// 2p+q is b=1 or
// 2q+p is b=0
//
// It first computes the x-coordinate of p+q via the slope(p,q)
// and then based on a Select adds either p or q.
func (c *Curve[B, S]) doubleAndAddSelect(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] {

// compute λ1 = (q.y-p.y)/(q.x-p.x)
yqyp := c.baseApi.Sub(&q.Y, &p.Y)
xqxp := c.baseApi.Sub(&q.X, &p.X)
λ1 := c.baseApi.Div(yqyp, xqxp)

// compute x2 = λ1²-p.x-q.x
λ1λ1 := c.baseApi.MulMod(λ1, λ1)
xqxp = c.baseApi.Add(&p.X, &q.X)
x2 := c.baseApi.Sub(λ1λ1, xqxp)

// ommit y2 computation

// conditional second addition
t := c.Select(b, p, q)

// compute λ2 = -λ1-2*t.y/(x2-t.x)
ypyp := c.baseApi.Add(&t.Y, &t.Y)
x2xp := c.baseApi.Sub(x2, &t.X)
λ2 := c.baseApi.Div(ypyp, x2xp)
λ2 = c.baseApi.Add(λ1, λ2)
λ2 = c.baseApi.Neg(λ2)

// compute x3 =λ2²-t.x-x3
λ2λ2 := c.baseApi.MulMod(λ2, λ2)
x3 := c.baseApi.Sub(λ2λ2, &t.X)
x3 = c.baseApi.Sub(x3, x2)

// compute y3 = λ2*(t.x - x3)-t.y
y3 := c.baseApi.Sub(&t.X, x3)
y3 = c.baseApi.Mul(λ2, y3)
y3 = c.baseApi.Sub(y3, &t.Y)

return &AffinePoint[B]{
X: *c.baseApi.Reduce(x3),
Y: *c.baseApi.Reduce(y3),
}

}

// Select selects between p and q given the selector b. If b == 1, then returns
// p and q otherwise.
func (c *Curve[B, S]) Select(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] {
Expand Down Expand Up @@ -393,17 +441,13 @@ func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *Affi
R1 := c.Select(sBits[1], p, Rb)

for i := 2; i < n-1; i++ {
Rb = c.Select(sBits[i], R0, R1)
Rk := c.Select(sBits[i], R1, R0)
Rb = c.doubleAndAdd(Rb, Rk)
Rb = c.doubleAndAddSelect(sBits[i], R0, R1)
R0 = c.Select(sBits[i], Rb, R0)
R1 = c.Select(sBits[i], R1, Rb)
}

// i = n-1
Rb = c.Select(sBits[n-1], R0, R1)
Rk := c.Select(sBits[n-1], R1, R0)
Rb = c.doubleAndAdd(Rb, Rk)
Rb = c.doubleAndAddSelect(sBits[n-1], R0, R1)
R0 = c.Select(sBits[n-1], Rb, R0)

// i = 0
Expand Down Expand Up @@ -499,9 +543,7 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele
R0 := c.Select(s2Bits[1], Rb, p)
R1 := c.Select(s2Bits[1], p, Rb)
// i = 2
Rb = c.Select(s2Bits[2], R0, R1)
Rk := c.Select(s2Bits[2], R1, R0)
Rb = c.doubleAndAdd(Rb, Rk)
Rb = c.doubleAndAddSelect(s2Bits[2], R0, R1)
R0 = c.Select(s2Bits[2], Rb, R0)
R1 = c.Select(s2Bits[2], R1, Rb)

Expand All @@ -511,21 +553,18 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele
tmp1 := c.add(res1, &gm[i])
res1 = c.Select(s1Bits[i], tmp1, res1)
// var-base
Rb = c.Select(s2Bits[i], R0, R1)
Rk = c.Select(s2Bits[i], R1, R0)
Rb = c.doubleAndAdd(Rb, Rk)
Rb = c.doubleAndAddSelect(s2Bits[i], R0, R1)
R0 = c.Select(s2Bits[i], Rb, R0)
R1 = c.Select(s2Bits[i], R1, Rb)

}

// i = n-2
// fixed-base
tmp1 := c.add(res1, &gm[n-2])
res1 = c.Select(s1Bits[n-2], tmp1, res1)
// var-base
Rb = c.Select(s2Bits[n-2], R0, R1)
Rk = c.Select(s2Bits[n-2], R1, R0)
Rb = c.doubleAndAdd(Rb, Rk)
Rb = c.doubleAndAddSelect(s2Bits[n-2], R0, R1)
R0 = c.Select(s2Bits[n-2], Rb, R0)
R1 = c.Select(s2Bits[n-2], R1, Rb)

Expand All @@ -534,17 +573,15 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele
tmp1 = c.add(res1, &gm[n-1])
res1 = c.Select(s1Bits[n-1], tmp1, res1)
// var-base
Rb = c.Select(s2Bits[n-1], R0, R1)
Rk = c.Select(s2Bits[n-1], R1, R0)
Rb = c.doubleAndAdd(Rb, Rk)
Rb = c.doubleAndAddSelect(s2Bits[n-1], R0, R1)
R0 = c.Select(s2Bits[n-1], Rb, R0)

// i = 0
// fixed-base
tmp1 = c.add(res1, c.Neg(g))
res1 = c.Select(s1Bits[0], res1, tmp1)
// var-base
R0 = c.Select(s2Bits[0], R0, c.AddUnified(R0, c.Neg(p)))
R0 = c.Select(s2Bits[0], R0, c.add(R0, c.Neg(p)))

return c.add(res1, R0)
}

0 comments on commit ceed757

Please sign in to comment.