diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 0132cc3e8f..0dbf653b6a 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -330,6 +330,54 @@ func (c *Curve[B, S]) doubleAndAdd(p, q *AffinePoint[B]) *AffinePoint[B] { } +// doubleAndAddSelect is the same as doubleAndAdd but computes either: +// +// 2p+q is b=1 or +// 2q+p is b=0 +// +// It first computes the x-coordinate of p+q via the slope(p,q) +// and then based on a Select adds either p or q. +func (c *Curve[B, S]) doubleAndAddSelect(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { + + // compute λ1 = (q.y-p.y)/(q.x-p.x) + yqyp := c.baseApi.Sub(&q.Y, &p.Y) + xqxp := c.baseApi.Sub(&q.X, &p.X) + λ1 := c.baseApi.Div(yqyp, xqxp) + + // compute x2 = λ1²-p.x-q.x + λ1λ1 := c.baseApi.MulMod(λ1, λ1) + xqxp = c.baseApi.Add(&p.X, &q.X) + x2 := c.baseApi.Sub(λ1λ1, xqxp) + + // ommit y2 computation + + // conditional second addition + t := c.Select(b, p, q) + + // compute λ2 = -λ1-2*t.y/(x2-t.x) + ypyp := c.baseApi.Add(&t.Y, &t.Y) + x2xp := c.baseApi.Sub(x2, &t.X) + λ2 := c.baseApi.Div(ypyp, x2xp) + λ2 = c.baseApi.Add(λ1, λ2) + λ2 = c.baseApi.Neg(λ2) + + // compute x3 =λ2²-t.x-x3 + λ2λ2 := c.baseApi.MulMod(λ2, λ2) + x3 := c.baseApi.Sub(λ2λ2, &t.X) + x3 = c.baseApi.Sub(x3, x2) + + // compute y3 = λ2*(t.x - x3)-t.y + y3 := c.baseApi.Sub(&t.X, x3) + y3 = c.baseApi.Mul(λ2, y3) + y3 = c.baseApi.Sub(y3, &t.Y) + + return &AffinePoint[B]{ + X: *c.baseApi.Reduce(x3), + Y: *c.baseApi.Reduce(y3), + } + +} + // Select selects between p and q given the selector b. If b == 1, then returns // p and q otherwise. func (c *Curve[B, S]) Select(b frontend.Variable, p, q *AffinePoint[B]) *AffinePoint[B] { @@ -393,17 +441,13 @@ func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *Affi R1 := c.Select(sBits[1], p, Rb) for i := 2; i < n-1; i++ { - Rb = c.Select(sBits[i], R0, R1) - Rk := c.Select(sBits[i], R1, R0) - Rb = c.doubleAndAdd(Rb, Rk) + Rb = c.doubleAndAddSelect(sBits[i], R0, R1) R0 = c.Select(sBits[i], Rb, R0) R1 = c.Select(sBits[i], R1, Rb) } // i = n-1 - Rb = c.Select(sBits[n-1], R0, R1) - Rk := c.Select(sBits[n-1], R1, R0) - Rb = c.doubleAndAdd(Rb, Rk) + Rb = c.doubleAndAddSelect(sBits[n-1], R0, R1) R0 = c.Select(sBits[n-1], Rb, R0) // i = 0 @@ -499,9 +543,7 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele R0 := c.Select(s2Bits[1], Rb, p) R1 := c.Select(s2Bits[1], p, Rb) // i = 2 - Rb = c.Select(s2Bits[2], R0, R1) - Rk := c.Select(s2Bits[2], R1, R0) - Rb = c.doubleAndAdd(Rb, Rk) + Rb = c.doubleAndAddSelect(s2Bits[2], R0, R1) R0 = c.Select(s2Bits[2], Rb, R0) R1 = c.Select(s2Bits[2], R1, Rb) @@ -511,11 +553,10 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele tmp1 := c.add(res1, &gm[i]) res1 = c.Select(s1Bits[i], tmp1, res1) // var-base - Rb = c.Select(s2Bits[i], R0, R1) - Rk = c.Select(s2Bits[i], R1, R0) - Rb = c.doubleAndAdd(Rb, Rk) + Rb = c.doubleAndAddSelect(s2Bits[i], R0, R1) R0 = c.Select(s2Bits[i], Rb, R0) R1 = c.Select(s2Bits[i], R1, Rb) + } // i = n-2 @@ -523,9 +564,7 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele tmp1 := c.add(res1, &gm[n-2]) res1 = c.Select(s1Bits[n-2], tmp1, res1) // var-base - Rb = c.Select(s2Bits[n-2], R0, R1) - Rk = c.Select(s2Bits[n-2], R1, R0) - Rb = c.doubleAndAdd(Rb, Rk) + Rb = c.doubleAndAddSelect(s2Bits[n-2], R0, R1) R0 = c.Select(s2Bits[n-2], Rb, R0) R1 = c.Select(s2Bits[n-2], R1, Rb) @@ -534,9 +573,7 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele tmp1 = c.add(res1, &gm[n-1]) res1 = c.Select(s1Bits[n-1], tmp1, res1) // var-base - Rb = c.Select(s2Bits[n-1], R0, R1) - Rk = c.Select(s2Bits[n-1], R1, R0) - Rb = c.doubleAndAdd(Rb, Rk) + Rb = c.doubleAndAddSelect(s2Bits[n-1], R0, R1) R0 = c.Select(s2Bits[n-1], Rb, R0) // i = 0 @@ -544,7 +581,7 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele tmp1 = c.add(res1, c.Neg(g)) res1 = c.Select(s1Bits[0], res1, tmp1) // var-base - R0 = c.Select(s2Bits[0], R0, c.AddUnified(R0, c.Neg(p))) + R0 = c.Select(s2Bits[0], R0, c.add(R0, c.Neg(p))) return c.add(res1, R0) }