Skip to content

Commit

Permalink
Merge pull request #1023 from Consensys/fix/ec-edgecases
Browse files Browse the repository at this point in the history
Fix: edge cases in `std/algebra` elliptic curve arithmetic circuit (`emulated` and `2-chains)`
  • Loading branch information
yelhousni authored Jan 29, 2024
2 parents 179023d + f3c3eeb commit 4c81525
Show file tree
Hide file tree
Showing 13 changed files with 1,631 additions and 1,068 deletions.
98 changes: 67 additions & 31 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,14 +604,12 @@ func (c *Curve[B, S]) scalarMulGeneric(p *AffinePoint[B], s *emulated.Element[S]
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
addFn := c.Add
var selector frontend.Variable
if cfg.UseSafe {
// if p=(0,0) we assign a dummy (0,1) to p and continue
selector = c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y))
one := c.baseApi.One()
p = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, p)
addFn = c.AddUnified
}

var st S
Expand All @@ -638,9 +636,10 @@ func (c *Curve[B, S]) scalarMulGeneric(p *AffinePoint[B], s *emulated.Element[S]
R0 = c.Select(sBits[n-1], Rb, R0)

// i = 0
// When cfg.UseSafe is set, we use AddUnified instead of Add. This means
// when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0).
R0 = c.Select(sBits[0], R0, addFn(R0, c.Neg(p)))
// we use AddUnified instead of Add. This is because:
// - when s=0 then R0=P and AddUnified(P, -P) = (0,0). We return (0,0).
// - when s=1 then R0=P AddUnified(Q, -Q) is well defined. We return R0=P.
R0 = c.Select(sBits[0], R0, c.AddUnified(R0, c.Neg(p)))

if cfg.UseSafe {
// if p=(0,0), return (0,0)
Expand All @@ -665,26 +664,36 @@ func (c *Curve[B, S]) jointScalarMul(p1, p2 *AffinePoint[B], s1, s2 *emulated.El
}
}

// jointScalarMulGeneric computes s1 * p1 + s2 * p2 and returns it. It doesn't modify the inputs.
// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify the inputs.
//
// ⚠️ The scalar s must be nonzero and the point Q different from (0,0).
// ⚠️ p1, p2 must not be (0,0) and s1, s2 must not be 0, unless [algopts.WithUseSafe] option is set.
func (c *Curve[B, S]) jointScalarMulGeneric(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
s1r := c.scalarApi.Reduce(s1)
s1Bits := c.scalarApi.ToBits(s1r)
s2r := c.scalarApi.Reduce(s2)
s2Bits := c.scalarApi.ToBits(s2r)

res := c.scalarBitsMul(p1, s1Bits, opts...)
tmp := c.scalarBitsMul(p2, s2Bits, opts...)

res = c.Add(res, tmp)
return res
res1 := c.scalarMulGeneric(p1, s1, opts...)
res2 := c.scalarMulGeneric(p2, s2, opts...)
return c.Add(res1, res2)
}

// jointScalarMulGLV computes P = [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify P, Q nor s.
// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify P, Q nor s.
//
// ⚠️ The scalar s must be nonzero and the point Q different from (0,0).
func (c *Curve[B, S]) jointScalarMulGLV(Q, R *AffinePoint[B], s, t *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
// ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithUseSafe] option is set.
func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
if cfg.UseSafe {
// TODO @yelhousni: optimize
res1 := c.scalarMulGLV(p1, s1, opts...)
res2 := c.scalarMulGLV(p2, s2, opts...)
return c.AddUnified(res1, res2)
} else {
return c.jointScalarMulGLVUnsafe(p1, p2, s1, s2)
}
}

// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify P, Q nor s.
// ⚠️ The scalar s must be nonzero and the point Q different from (0,0), unless [algopts.WithUseSafe] option is set.
func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] {
var st S
frModulus := c.scalarApi.Modulus()
sd, err := c.scalarApi.NewHint(decomposeScalarG1, 5, s, c.eigenvalue, frModulus)
Expand Down Expand Up @@ -806,15 +815,23 @@ func (c *Curve[B, S]) jointScalarMulGLV(Q, R *AffinePoint[B], s, t *emulated.Ele
Acc = c.Select(t2bits[0], Acc, tablePhiR[0])

return Acc

}

// scalarBitsMul computes s * p and returns it where sBits is the bit decomposition of s. It doesn't modify p nor sBits.
// ⚠️ Point and scalar must be nonzero.
func (c *Curve[B, S]) scalarBitsMul(p *AffinePoint[B], sBits []frontend.Variable, opts ...algopts.AlgebraOption) *AffinePoint[B] {
// scalarBitsMulGeneric computes s * p and returns it where sBits is the bit decomposition of s. It doesn't modify p nor sBits.
// ⚠️ p must not be (0,0) and sBits not [0,...,0], unless [algopts.WithUseSafe] option is set.
func (c *Curve[B, S]) scalarBitsMulGeneric(p *AffinePoint[B], sBits []frontend.Variable, opts ...algopts.AlgebraOption) *AffinePoint[B] {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
var selector frontend.Variable
if cfg.UseSafe {
// if p=(0,0) we assign a dummy (0,1) to p and continue
selector = c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y))
one := c.baseApi.One()
p = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, p)
}

var st S
n := st.Modulus().BitLen()
Expand All @@ -838,8 +855,17 @@ func (c *Curve[B, S]) scalarBitsMul(p *AffinePoint[B], sBits []frontend.Variable
R0 = c.Select(sBits[n-1], Rb, R0)

// i = 0
// we use AddUnified instead of Add. This is because:
// - when s=0 then R0=P and AddUnified(P, -P) = (0,0). We return (0,0).
// - when s=1 then R0=P AddUnified(Q, -Q) is well defined. We return R0=P.
R0 = c.Select(sBits[0], R0, c.AddUnified(R0, c.Neg(p)))

if cfg.UseSafe {
// if p=(0,0), return (0,0)
zero := c.baseApi.Zero()
R0 = c.Select(selector, &AffinePoint[B]{X: *zero, Y: *zero}, R0)
}

return R0
}

Expand Down Expand Up @@ -884,7 +910,13 @@ func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S], opts ...algopts.Alge
}

// i = 0
tmp := c.AddUnified(res, c.Neg(g))
// When cfg.UseSafe is set, we use AddUnified instead of Add. This means
// when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0).
addFn := c.Add
if cfg.UseSafe {
addFn = c.AddUnified
}
tmp := addFn(res, c.Neg(g))
res = c.Select(sBits[0], res, tmp)

return res
Expand All @@ -911,8 +943,8 @@ func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S], opts ...algopts.Alge
//
// This saves the Select logic related to (0,0) and the use of AddUnified to
// handle the 0-scalar edge case.
func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Element[S]) *AffinePoint[B] {
return c.jointScalarMul(c.Generator(), p, s1, s2)
func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
return c.jointScalarMul(c.Generator(), p, s1, s2, opts...)
}

// MultiScalarMul computes the multi scalar multiplication of the points P and
Expand All @@ -932,6 +964,10 @@ func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[
if err != nil {
return nil, fmt.Errorf("new config: %w", err)
}
addFn := c.Add
if cfg.UseSafe {
addFn = c.AddUnified
}
if !cfg.FoldMulti {
// the scalars are unique
if len(p) != len(s) {
Expand All @@ -946,7 +982,7 @@ func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[
}
for i := 1; i < n-1; i += 2 {
q := c.jointScalarMul(p[i-1], p[i], s[i-1], s[i], opts...)
res = c.Add(res, q)
res = addFn(res, q)
}
return res, nil
} else {
Expand All @@ -957,12 +993,12 @@ func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[
gamma := s[0]
gamma = c.scalarApi.Reduce(gamma)
gammaBits := c.scalarApi.ToBits(gamma)
res := c.scalarBitsMul(p[len(p)-1], gammaBits, opts...)
res := c.scalarBitsMulGeneric(p[len(p)-1], gammaBits, opts...)
for i := len(p) - 2; i > 0; i-- {
res = c.Add(p[i], res)
res = c.scalarBitsMul(res, gammaBits, opts...)
res = addFn(p[i], res)
res = c.scalarBitsMulGeneric(res, gammaBits, opts...)
}
res = c.Add(p[0], res)
res = addFn(p[0], res)
return res, nil
}
}
Loading

0 comments on commit 4c81525

Please sign in to comment.