From 4545ee118338f16d8b84cbc80370dfe9a0dae50a Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 10:57:28 +0000 Subject: [PATCH 01/15] perf: sum over limbs for IsZero --- std/math/emulated/field_assert.go | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 692db8ef2e..9bbc9df1a9 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -196,11 +196,31 @@ func (f *Field[T]) AssertIsInRange(a *Element[T]) { func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { ca := f.Reduce(a) f.AssertIsInRange(ca) - res := f.api.IsZero(ca.Limbs[0]) + // we use two approaches for checking if the element is exactly zero. The + // first approach is to check that every limb individually is zero. The + // second approach is to check if the sum of all limbs is zero. Usually, we + // cannot use this approach as we could have false positive due to overflow + // in the native field. However, as the widths of the limbs are restricted, + // then we can ensure in most cases that no overflows happen. + + // as ca is already reduced, then every limb overflow is already 0. Only + // every addition adds a bit to the overflow + totalOverflow := len(ca.Limbs) - 1 + if totalOverflow < int(f.maxOverflow()) { + // the sums of limbs would overflow the native field. Use the first + // approach instead. + res := f.api.IsZero(ca.Limbs[0]) + for i := 1; i < len(ca.Limbs); i++ { + res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + } + return res + } + // default case, limbs sum does not overflow the native field + limbSum := ca.Limbs[0] for i := 1; i < len(ca.Limbs); i++ { - res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + limbSum = f.api.Add(limbSum, ca.Limbs[i]) } - return res + return f.api.IsZero(limbSum) } // // Cmp returns: From 185d09c413fe73d3ba4ea7bd6ee2fa814a0881e3 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:08:32 +0000 Subject: [PATCH 02/15] perf: use mulmod for equality assertion --- std/math/emulated/field_assert.go | 14 +------ std/math/emulated/field_mul.go | 67 +++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 9bbc9df1a9..dc991dbc9c 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -129,19 +129,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { } diff := f.Sub(b, a) - - // we compute k such that diff / p == k - // so essentially, we say "I know an element k such that k*p == diff" - // hence, diff == 0 mod p - p := f.Modulus() - k, err := f.computeQuoHint(diff) - if err != nil { - panic(fmt.Sprintf("hint error: %v", err)) - } - - kp := f.reduceAndOp(f.mul, f.mulPreCond, k, p) - - f.AssertLimbsEquality(diff, kp) + f.checkZero(diff) } // AssertIsLessOrEqual ensures that e is less or equal than a. For proper diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 964a4f3058..a6f4b7b3b3 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -122,6 +122,26 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { return r } +// checkZero creates multiplication check a * 1 = 0 + k*p. +func (f *Field[T]) checkZero(a *Element[T]) { + // the method works similarly to mulMod, but we know that we are multiplying + // by one and expected result should be zero. + f.enforceWidthConditional(a) + k, r, c, err := f.callCheckZeroHint(a) + if err != nil { + panic(err) + } + mc := mulCheck[T]{ + f: f, + a: a, + b: f.shortOne(), // one on single limb to speed up the polynomial evaluation + c: c, + k: k, + r: r, // expected to be zero on zero limbs. + } + f.mulChecks = append(f.mulChecks, mc) +} + // evalWithChallenge represents element a as a polynomial a(X) and evaluates at // at[0]. For efficiency, we use already evaluated powers of at[0] given by at. // It stores the evaluation result inside the Element and marks it as evaluated. @@ -243,6 +263,53 @@ func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], return } +func (f *Field[T]) callCheckZeroHint(a *Element[T]) (quo, rem, carries *Element[T], err error) { + nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() + // this is the same as callMulHint, but we know that we multiply a by one. + // This allows to optimize some bounds - namely we know that the overflow of + // the product would be the same as the overflow of a as we would be + // multiplying every limb of a by one. Secondly, the number of limbs of the + // result is also the same as for a. + nextOverflow := a.overflow + nbActualQuoLimbs := (uint(len(a.Limbs))*nbBits + nextOverflow + 1 - // + uint(f.fParams.Modulus().BitLen()) + // + nbBits - 1) / + nbBits + + // also compute the number of limbs assuming that we have 1 on full + // number of limbs to keep compatibility with the multiplication hint + // (which assumes that the number of limbs is the same as the input). + // Otherwise, we would have to provide additional inputs to indicate the + // length of the quotient. + nbFullQuoLimbs := ((2*nbLimbs-1)*nbBits + nextOverflow + 1 - // + uint(f.fParams.Modulus().BitLen()) + // + nbBits - 1) / + nbBits + nbRemLimbs := nbLimbs + nbCarryLimbs := (nbFullQuoLimbs + nbLimbs) - 2 + hintInputs := []frontend.Variable{ + nbBits, + nbLimbs, + } + hintInputs = append(hintInputs, f.Modulus().Limbs...) + hintInputs = append(hintInputs, a.Limbs...) + hintInputs = append(hintInputs, f.One().Limbs...) + ret, err := f.api.NewHint(mulHint, int(nbFullQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) + if err != nil { + err = fmt.Errorf("call hint: %w", err) + return + } + // now, we only pack the number of actual expected non-zero limbs for + // quotient. This makes later in the multiplication check later cheaper as + // we evaluate the polynomial with limbs as coefficients. + quo = f.packLimbs(ret[:nbActualQuoLimbs], false) + // and the remainder is supposed to be 0. As a polynomial this means it will + // be evaluated to zero. + rem = &Element[T]{} + carries = f.newInternalElement(ret[nbFullQuoLimbs+nbRemLimbs:], 0) + return +} + func mulHint(field *big.Int, inputs, outputs []*big.Int) error { nbBits := int(inputs[0].Int64()) nbLimbs := int(inputs[1].Int64()) From 41c410c3a917b0696b46192be54eb5677fc59e3c Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:09:02 +0000 Subject: [PATCH 03/15] fix: handle edge case in mulcheck with zero limbs --- std/math/emulated/field_mul.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index a6f4b7b3b3..e835b8adc9 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -153,7 +153,10 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele if len(at) < len(a.Limbs)-1 { panic("evaluation powers less than limbs") } - sum := f.api.Mul(a.Limbs[0], 1) // copy because we use MulAcc + var sum frontend.Variable = 0 + if len(a.Limbs) > 0 { + sum = f.api.Mul(a.Limbs[0], 1) // copy because we use MulAcc + } for i := 1; i < len(a.Limbs); i++ { sum = f.api.MulAcc(sum, a.Limbs[i], at[i-1]) } From 7f47faf77fdd0deab3488a1f28228164a2536ca1 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:09:16 +0000 Subject: [PATCH 04/15] refactor: do not use temp var --- std/math/emulated/field_mul.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index e835b8adc9..0e94fc4aed 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -199,10 +199,9 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { coefsLen = max(coefsLen, len(f.mulChecks[i].c.Limbs)) } at := make([]frontend.Variable, coefsLen) - var prev frontend.Variable = 1 - for i := range at { - at[i] = api.Mul(prev, commitment) - prev = at[i] + at[0] = commitment + for i := 1; i < len(at); i++ { + at[i] = api.Mul(at[i-1], commitment) } // evaluate all r, k, c for i := range f.mulChecks { From a9ebbef0752d3ec543eca09f7d2d17c05e54d656 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:10:46 +0000 Subject: [PATCH 05/15] feat: remove AssertLimbsEquality --- std/math/emulated/element_test.go | 31 ---------------------------- std/math/emulated/field_assert.go | 34 ------------------------------- 2 files changed, 65 deletions(-) diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 1bd7521f3e..fec7b6e0ba 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -17,42 +17,11 @@ import ( const testCurve = ecc.BN254 -type AssertLimbEqualityCircuit[T FieldParams] struct { - A, B Element[T] -} - -func (c *AssertLimbEqualityCircuit[T]) Define(api frontend.API) error { - f, err := NewField[T](api) - if err != nil { - return err - } - f.AssertLimbsEquality(&c.A, &c.B) - return nil -} - func testName[T FieldParams]() string { var fp T return fmt.Sprintf("%s/limb=%d", reflect.TypeOf(fp).Name(), fp.BitsPerLimb()) } -func TestAssertLimbEqualityNoOverflow(t *testing.T) { - testAssertLimbEqualityNoOverflow[Goldilocks](t) - testAssertLimbEqualityNoOverflow[Secp256k1Fp](t) - testAssertLimbEqualityNoOverflow[BN254Fp](t) -} - -func testAssertLimbEqualityNoOverflow[T FieldParams](t *testing.T) { - var fp T - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - var circuit, witness AssertLimbEqualityCircuit[T] - val, _ := rand.Int(rand.Reader, fp.Modulus()) - witness.A = ValueOf[T](val) - witness.B = ValueOf[T](val) - assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness)) - }, testName[T]()) -} - // TODO: add also cases which should fail type AssertIsLessEqualThanCircuit[T FieldParams] struct { diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index dc991dbc9c..384626b2fb 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -2,7 +2,6 @@ package emulated import ( "fmt" - "math/big" "github.com/consensys/gnark/frontend" ) @@ -56,39 +55,6 @@ func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.V f.api.AssertIsEqual(composed, v) return shifted[0] } - -// AssertLimbsEquality asserts that the limbs represent a same integer value. -// This method does not ensure that the values are equal modulo the field order. -// For strict equality, use AssertIsEqual. -func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) { - f.enforceWidthConditional(a) - f.enforceWidthConditional(b) - ba, aConst := f.constantValue(a) - bb, bConst := f.constantValue(b) - if aConst && bConst { - ba.Mod(ba, f.fParams.Modulus()) - bb.Mod(bb, f.fParams.Modulus()) - if ba.Cmp(bb) != 0 { - panic(fmt.Errorf("constant values are different: %s != %s", ba.String(), bb.String())) - } - return - } - - // first, we check if we can compact a and b; they could be using 8 limbs of 32bits - // but with our snark field, we could express them in 2 limbs of 128bits, which would make bit decomposition - // and limbs equality in-circuit (way) cheaper - ca, cb, bitsPerLimb := f.compact(a, b) - - // slow path -- the overflows are different. Need to compare with carries. - // TODO: we previously assumed that one side was "larger" than the other - // side, but I think this assumption is not valid anymore - if a.overflow > b.overflow { - f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) - } else { - f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) - } -} - // enforceWidth enforces the width of the limbs. When modWidth is true, then the // limbs are asserted to be the width of the modulus (highest limb may be less // than full limb width). Otherwise, every limb is assumed to have same width From 7838ce054ede7b043a79643823b9b6edff548207 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:11:48 +0000 Subject: [PATCH 06/15] feat: implement shortOne() method --- std/math/emulated/field.go | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 3480aaab31..70e327a29d 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -30,14 +30,16 @@ type Field[T FieldParams] struct { maxOfOnce sync.Once // constants for often used elements n, 0 and 1. Allocated only once - nConstOnce sync.Once - nConst *Element[T] - nprevConstOnce sync.Once - nprevConst *Element[T] - zeroConstOnce sync.Once - zeroConst *Element[T] - oneConstOnce sync.Once - oneConst *Element[T] + nConstOnce sync.Once + nConst *Element[T] + nprevConstOnce sync.Once + nprevConst *Element[T] + zeroConstOnce sync.Once + zeroConst *Element[T] + oneConstOnce sync.Once + oneConst *Element[T] + shortOneConstOnce sync.Once + shortOneConst *Element[T] log zerolog.Logger @@ -146,6 +148,14 @@ func (f *Field[T]) One() *Element[T] { return f.oneConst } +// shortOne returns one as a constant stored in a single limb. +func (f *Field[T]) shortOne() *Element[T] { + f.shortOneConstOnce.Do(func() { + f.shortOneConst = f.newInternalElement([]frontend.Variable{1}, 0) + }) + return f.shortOneConst +} + // Modulus returns the modulus of the emulated ring as a constant. func (f *Field[T]) Modulus() *Element[T] { f.nConstOnce.Do(func() { From 871ab1c2e00522626b7d37aacf15a025b47734db Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:12:51 +0000 Subject: [PATCH 07/15] chore: remove unused private methods --- std/math/emulated/field.go | 45 ---------- std/math/emulated/field_assert.go | 49 ----------- std/math/emulated/field_mul.go | 44 ---------- std/math/emulated/field_ops.go | 11 --- std/math/emulated/hints.go | 141 ------------------------------ 5 files changed, 290 deletions(-) diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 70e327a29d..6c1f19b04d 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -258,51 +258,6 @@ func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) { return res, true } -// compact returns parameters which allow for most optimal regrouping of -// limbs. In regrouping the limbs, we encode multiple existing limbs as a linear -// combination in a single new limb. -// compact returns a and b minimal (in number of limbs) representation that fits in the snark field -func (f *Field[T]) compact(a, b *Element[T]) (ac, bc []frontend.Variable, bitsPerLimb uint) { - // omit width reduction as is done in the calling method already - maxOverflow := max(a.overflow, b.overflow) - // subtract one bit as can not potentially use all bits of Fr and one bit as - // grouping may overflow - maxNbBits := uint(f.api.Compiler().FieldBitLen()) - 2 - maxOverflow - groupSize := maxNbBits / f.fParams.BitsPerLimb() - if groupSize == 0 { - // no space for compact - return a.Limbs, b.Limbs, f.fParams.BitsPerLimb() - } - - bitsPerLimb = f.fParams.BitsPerLimb() * groupSize - - ac = f.compactLimbs(a, groupSize, bitsPerLimb) - bc = f.compactLimbs(b, groupSize, bitsPerLimb) - return -} - -// compactLimbs perform the regrouping of limbs between old and new parameters. -func (f *Field[T]) compactLimbs(e *Element[T], groupSize, bitsPerLimb uint) []frontend.Variable { - if f.fParams.BitsPerLimb() == bitsPerLimb { - return e.Limbs - } - nbLimbs := (uint(len(e.Limbs)) + groupSize - 1) / groupSize - r := make([]frontend.Variable, nbLimbs) - coeffs := make([]*big.Int, groupSize) - one := big.NewInt(1) - for i := range coeffs { - coeffs[i] = new(big.Int) - coeffs[i].Lsh(one, f.fParams.BitsPerLimb()*uint(i)) - } - for i := uint(0); i < nbLimbs; i++ { - r[i] = uint(0) - for j := uint(0); j < groupSize && i*groupSize+j < uint(len(e.Limbs)); j++ { - r[i] = f.api.Add(r[i], f.api.Mul(coeffs[j], e.Limbs[i*groupSize+j])) - } - } - return r -} - // maxOverflow returns the maximal possible overflow for the element. If the // overflow of the next operation exceeds the value returned by this method, // then the limbs may overflow the native field. diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 384626b2fb..a2809e4eb9 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -6,55 +6,6 @@ import ( "github.com/consensys/gnark/frontend" ) -// assertLimbsEqualitySlow is the main routine in the package. It asserts that the -// two slices of limbs represent the same integer value. This is also the most -// costly operation in the package as it does bit decomposition of the limbs. -func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { - - nbLimbs := max(len(l), len(r)) - maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) - maxValueShift := new(big.Int).Lsh(big.NewInt(1), nbCarryBits) - - var carry frontend.Variable = 0 - for i := 0; i < nbLimbs; i++ { - diff := api.Add(maxValue, carry) - if i < len(l) { - diff = api.Add(diff, l[i]) - } - if i < len(r) { - diff = api.Sub(diff, r[i]) - } - if i > 0 { - diff = api.Sub(diff, maxValueShift) - } - - // carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1] - // we know that diff[:nbBits] are 0 bits, but still need to constrain them. - // to do both; we do a "clean" right shift and only need to boolean constrain the carry part - carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1)) - } - api.AssertIsEqual(carry, maxValueShift) -} - -func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable { - // if v is a constant, work with the big int value. - if c, ok := f.api.Compiler().ConstantValue(v); ok { - bits := make([]frontend.Variable, endDigit-startDigit) - for i := 0; i < len(bits); i++ { - bits[i] = c.Bit(i + startDigit) - } - return bits - } - shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v) - if err != nil { - panic(fmt.Sprintf("right shift: %v", err)) - } - f.checker.Check(shifted[0], endDigit-startDigit) - shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit)) - composed := f.api.Mul(shifted[0], shift) - f.api.AssertIsEqual(composed, v) - return shifted[0] -} // enforceWidth enforces the width of the limbs. When modWidth is true, then the // limbs are asserted to be the width of the modulus (highest limb may be less // than full limb width). Otherwise, every limb is assumed to have same width diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 0e94fc4aed..b0f395b9a0 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -453,47 +453,3 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { } return } - -func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { - // TODO: kept for [AssertIsEqual]. Consider if this can be removed and we - // can use MulMod for equality assertion. - ba, aConst := f.constantValue(a) - bb, bConst := f.constantValue(b) - if aConst && bConst { - ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) - return newConstElement[T](ba) - } - - // mulResult contains the result (out of circuit) of a * b school book multiplication - // len(mulResult) == len(a) + len(b) - 1 - mulResult, err := f.computeMultiplicationHint(a.Limbs, b.Limbs) - if err != nil { - panic(fmt.Sprintf("multiplication hint: %s", err)) - } - - // we computed the result of the mul outside the circuit (mulResult) - // and we want to constrain inside the circuit that this injected value - // actually matches the in-circuit a * b values - // create constraints (\sum_{i=0}^{m-1} a_i c^i) * (\sum_{i=0}^{m-1} b_i - // c^i) = (\sum_{i=0}^{2m-2} z_i c^i) for c \in {1, 2m-1} - w := new(big.Int) - for c := 1; c <= len(mulResult); c++ { - w.SetInt64(1) // c^i - l := f.api.Mul(a.Limbs[0], 1) - r := f.api.Mul(b.Limbs[0], 1) - o := f.api.Mul(mulResult[0], 1) - - for i := 1; i < len(mulResult); i++ { - w.Lsh(w, uint(c)) - if i < len(a.Limbs) { - l = f.api.MulAcc(l, a.Limbs[i], w) - } - if i < len(b.Limbs) { - r = f.api.MulAcc(r, b.Limbs[i], w) - } - o = f.api.MulAcc(o, mulResult[i], w) - } - f.api.AssertIsEqual(f.api.Mul(l, r), o) - } - return f.newInternalElement(mulResult, nextOverflow) -} diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index b2b96e0ace..4115089a8c 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -153,17 +153,6 @@ func (f *Field[T]) Sub(a, b *Element[T]) *Element[T] { return f.reduceAndOp(f.sub, f.subPreCond, a, b) } -// subReduce returns a-b and returns it. Contrary to [Field[T].Sub] method this -// method does not reduce the inputs if the result would overflow. This method -// is currently only used as a subroutine in [Field[T].Reduce] method to avoid -// infinite recursion when we are working exactly on the overflow limits. -func (f *Field[T]) subNoReduce(a, b *Element[T]) *Element[T] { - nextOverflow, _ := f.subPreCond(a, b) - // we ignore error as it only indicates if we should reduce or not. But we - // are in non-reducing version of sub. - return f.sub(a, b, nextOverflow) -} - func (f *Field[T]) subPreCond(a, b *Element[T]) (nextOverflow uint, err error) { reduceRight := a.overflow < (b.overflow + 1) nextOverflow = max(b.overflow+1, a.overflow) + 1 diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 16a560f9c7..6c1644c407 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -19,27 +19,12 @@ func init() { func GetHints() []solver.Hint { return []solver.Hint{ DivHint, - QuoHint, InverseHint, - MultiplicationHint, - RightShift, SqrtHint, mulHint, } } -// computeMultiplicationHint packs the inputs for the MultiplicationHint hint function. -func (f *Field[T]) computeMultiplicationHint(leftLimbs, rightLimbs []frontend.Variable) (mulLimbs []frontend.Variable, err error) { - hintInputs := []frontend.Variable{ - f.fParams.BitsPerLimb(), - len(leftLimbs), - len(rightLimbs), - } - hintInputs = append(hintInputs, leftLimbs...) - hintInputs = append(hintInputs, rightLimbs...) - return f.api.NewHint(MultiplicationHint, nbMultiplicationResLimbs(len(leftLimbs), len(rightLimbs)), hintInputs...) -} - // nbMultiplicationResLimbs returns the number of limbs which fit the // multiplication result. func nbMultiplicationResLimbs(lenLeft, lenRight int) int { @@ -50,87 +35,6 @@ func nbMultiplicationResLimbs(lenLeft, lenRight int) int { return res } -// MultiplicationHint unpacks the factors and parameters from inputs, computes -// the product and stores it in output. See internal method -// computeMultiplicationHint for the input packing. -func MultiplicationHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) < 3 { - return fmt.Errorf("input must be at least three elements") - } - nbBits := int(inputs[0].Int64()) - if 2*nbBits+1 >= mod.BitLen() { - return fmt.Errorf("can not fit multiplication result into limb of length %d", nbBits) - } - // TODO: check that the scalar field fits 2*nbBits + nbLimbs. 2*nbBits comes - // from multiplication and nbLimbs comes from additions. - // TODO: check that all limbs all fully reduced - nbLimbsLeft := int(inputs[1].Int64()) - // TODO: get the limb length from the input instead of packing into input - nbLimbsRight := int(inputs[2].Int64()) - if len(inputs) != 3+nbLimbsLeft+nbLimbsRight { - return fmt.Errorf("input invalid") - } - if len(outputs) < nbLimbsLeft+nbLimbsRight-1 { - return fmt.Errorf("can not fit multiplication result into %d limbs", len(outputs)) - } - for _, oi := range outputs { - if oi == nil { - return fmt.Errorf("output not initialized") - } - oi.SetUint64(0) - } - tmp := new(big.Int) - for i, li := range inputs[3 : 3+nbLimbsLeft] { - for j, rj := range inputs[3+nbLimbsLeft:] { - outputs[i+j].Add(outputs[i+j], tmp.Mul(li, rj)) - } - } - return nil -} - -// computeQuoHint packs the inputs for QuoHint function and returns z = x / y -// (discards remainder) -func (f *Field[T]) computeQuoHint(x *Element[T]) (z *Element[T], err error) { - var fp T - resLen := (uint(len(x.Limbs))*fp.BitsPerLimb() + x.overflow + 1 - // diff total bitlength - uint(fp.Modulus().BitLen()) + // subtract modulus bitlength - fp.BitsPerLimb() - 1) / // to round up - fp.BitsPerLimb() - - hintInputs := []frontend.Variable{ - fp.BitsPerLimb(), - len(x.Limbs), - } - p := f.Modulus() - hintInputs = append(hintInputs, x.Limbs...) - hintInputs = append(hintInputs, p.Limbs...) - - limbs, err := f.api.NewHint(QuoHint, int(resLen), hintInputs...) - if err != nil { - return nil, err - } - - return f.packLimbs(limbs, false), nil -} - -// QuoHint sets z to the quotient x/y for y != 0 and returns z. -// If y == 0, returns an error. -// Quo implements truncated division (like Go); see QuoRem for more details. -func QuoHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { - nbBits, _, x, y, err := parseHintDivInputs(inputs) - if err != nil { - return err - } - z := new(big.Int) - z.Quo(x, y) //.Mod(z, y) - - if err := decompose(z, nbBits, outputs); err != nil { - return fmt.Errorf("decompose: %w", err) - } - - return nil -} - // computeInverseHint packs the inputs for the InverseHint hint function. func (f *Field[T]) computeInverseHint(inLimbs []frontend.Variable) (inverseLimbs []frontend.Variable, err error) { var fp T @@ -232,51 +136,6 @@ func DivHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } -// input[0] = nbBits per limb -// input[1] = nbLimbs(x) -// input[2:2+nbLimbs(x)] = limbs(x) -// input[2+nbLimbs(x):] = limbs(y) -// errors if y == 0 -func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error) { - if len(inputs) < 2 { - return 0, 0, nil, nil, fmt.Errorf("at least 2 inputs required") - } - nbBits := uint(inputs[0].Uint64()) - nbLimbs := int(inputs[1].Int64()) - if len(inputs[2:]) < nbLimbs { - return 0, 0, nil, nil, fmt.Errorf("x limbs missing") - } - x, y := new(big.Int), new(big.Int) - if err := recompose(inputs[2:2+nbLimbs], nbBits, x); err != nil { - return 0, 0, nil, nil, fmt.Errorf("recompose x: %w", err) - } - if err := recompose(inputs[2+nbLimbs:], nbBits, y); err != nil { - return 0, 0, nil, nil, fmt.Errorf("recompose y: %w", err) - } - if y.IsUint64() && y.Uint64() == 0 { - return 0, 0, nil, nil, fmt.Errorf("y == 0") - } - return nbBits, nbLimbs, x, y, nil -} - -// RightShift shifts input by the given number of bits. Expects two inputs: -// - first input is the shift, will be represented as uint64; -// - second input is the value to be shifted. -// -// Returns a single output which is the value shifted. Errors if number of -// inputs is not 2 and number of outputs is not 1. -func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) != 2 { - return fmt.Errorf("expecting two inputs") - } - if len(outputs) != 1 { - return fmt.Errorf("expecting single output") - } - shift := inputs[0].Uint64() - outputs[0].Rsh(inputs[1], uint(shift)) - return nil -} - // SqrtHint compute square root of the input. func SqrtHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error { From 537378e807c6a6cc0843a41e6c984f5fd0530b4b Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 15:16:42 +0000 Subject: [PATCH 08/15] docs: equality assertion --- std/math/emulated/doc.go | 68 +++++----------------------------------- 1 file changed, 8 insertions(+), 60 deletions(-) diff --git a/std/math/emulated/doc.go b/std/math/emulated/doc.go index cb04ce4a38..61a6e54288 100644 --- a/std/math/emulated/doc.go +++ b/std/math/emulated/doc.go @@ -88,7 +88,8 @@ The complexity of native limb-wise multiplication is k^2. This translates directly to the complexity in the number of constraints in the constraint system. -For multiplication, we would instead use polynomial representation of the elements: +For multiplication, we would instead use polynomial representation of the +elements: x = ∑_{i=0}^k x_i 2^{w i} y = ∑_{i=0}^k y_i 2^{w i}. @@ -140,68 +141,15 @@ larger than every limb of b. The subtraction is performed as # Equality checking -The package provides two ways to check equality -- limb-wise equality check and -checking equality by value. +Equality checking is performed using modular multiplication. To check that a, b +are equal modulo r, we compute -In the limb-wise equality check we check that the integer values of the elements -x and y are equal. We have to carry the excess using bit decomposition (which -makes the computation fairly inefficient). To reduce the number of bit -decompositions, we instead carry over the excess of the difference of the limbs -instead. As we take the difference, then similarly as computing the padding in -subtraction algorithm, we need to add padding to the limbs before subtracting -limb-wise to avoid underflows. However, the padding in this case is slightly -different -- we do not need the padding to be divisible by the modulus, but -instead need that the limb padding is larger than the limb which is being -subtracted. + diff = b-a, -Lets look at the algorithm itself. We assume that the overflow f of x is larger -than y. If overflow of y is larger, then we can just swap the arguments and -apply the same argumentation. Let +and enforce modular multiplication check using the techniques for modular +multiplication: - maxValue = 1 << (k+f), // padding for limbs - maxValueShift = 1 << f. // carry part of the padding - -For every limb we compute the difference as - - diff_0 = maxValue+x_0-y_0, - diff_i = maxValue+carry_i+x_i-y_i-maxValueShift. - -We check that the normal part of the difference is zero and carry the rest over -to next limb: - - diff_i[0:k] == 0, - carry_{i+1} = diff_i[k:k+f+1] // we also carry over the padding bit. - -Finally, after we have compared all the limbs, we still need to check that the -final carry corresponds to the padding. We add final check: - - carry_k == maxValueShift. - -We can further optimise the limb-wise equality check by first regrouping the -limbs. The idea is to group several limbs so that the result would still fit -into the scalar field. If - - x = ∑_{i=0}^k x_i 2^{w i}, - -then we can instead take w' divisible by w such that - - x = ∑_{i=0}^(k/(w'/w)) x'_i 2^{w' i}, - -where - - x'_j = ∑_{i=0}^(w'/w) x_{j*w'/w+i} 2^{w i}. - -For element value equality check, we check that two elements x and y are equal -modulo r and for that we need to show that r divides x-y. As mentioned in the -subtraction section, we add sufficient padding such that x-y does not underflow -and its integer value is always larger than 0. We use hint function to compute z -such that - - x-y = z*r, - -compute z*r and use limbwise equality checking to show that - - x-y == z*r. + diff * 1 = 0 + k * r. # Bitwidth enforcement From b4c68446bfaabd41cae6c9b041567300a0c48dc2 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 23:43:52 +0000 Subject: [PATCH 09/15] fix: deduce maximum degree from all mulcheck inputs --- std/math/emulated/field_mul.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index b0f395b9a0..89fbc1a8d2 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -194,9 +194,10 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { // for efficiency, we compute all powers of the challenge as slice at. - coefsLen := 0 + coefsLen := int(f.fParams.NbLimbs()) for i := range f.mulChecks { - coefsLen = max(coefsLen, len(f.mulChecks[i].c.Limbs)) + coefsLen = max(coefsLen, len(f.mulChecks[i].a.Limbs), len(f.mulChecks[i].b.Limbs), + len(f.mulChecks[i].c.Limbs), len(f.mulChecks[i].k.Limbs)) } at := make([]frontend.Variable, coefsLen) at[0] = commitment From 0b091f05d93af0249b9ced52587e2732a48acfb4 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 21 Feb 2024 23:44:13 +0000 Subject: [PATCH 10/15] test: enable all mul tests --- std/math/emulated/element_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index fec7b6e0ba..c8e5c817b5 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -153,9 +153,9 @@ func (c *MulNoOverflowCircuit[T]) Define(api frontend.API) error { } func TestMulCircuitNoOverflow(t *testing.T) { - // testMulCircuitNoOverflow[Goldilocks](t) + testMulCircuitNoOverflow[Goldilocks](t) testMulCircuitNoOverflow[Secp256k1Fp](t) - // testMulCircuitNoOverflow[BN254Fp](t) + testMulCircuitNoOverflow[BN254Fp](t) } func testMulCircuitNoOverflow[T FieldParams](t *testing.T) { From 93f45a7eb9917ba5e10d861c0ece95a5d109b459 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Thu, 22 Feb 2024 00:02:06 +0000 Subject: [PATCH 11/15] chore: stats --- internal/stats/latest.stats | Bin 3067 -> 3067 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index 34a6ea8798404b732de31db44b1de52f1193f2a4..c813f57bd38adefd83563f8c2a8c3d34214f9871 100644 GIT binary patch delta 430 zcmew@{#$%P=43slhm#XH{!GqhUozQ+t7Y<1)`yeB8Lv;wd^ee&`2nN&!g6tN;K2 delta 470 zcmew@{#$%P=43V2ca!~Du20N-I9Y`4!DL;I*2(>xcPHOumow4L3S<21rErM>3OE@5 zZvFj;@$YKEO*n-n=P*B>{DtkyG4T3h07~Lv zD1QU0k-(ycfx`~3M!1={Oql$FsV#&L9yH}<=4OT{VUrwL#Q1md^(aI@{9&xBW@7xq PnDlZU9>K{QIPL=go3Ekl From ac1a7283ecc3685ff371dafff807656075a8153f Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 28 Feb 2024 18:42:06 +0000 Subject: [PATCH 12/15] refactor: generic impl for assert/mul --- std/math/emulated/field_mul.go | 100 +++++++++++---------------------- 1 file changed, 34 insertions(+), 66 deletions(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 89fbc1a8d2..963cc45eb6 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -106,7 +106,7 @@ func (mc *mulCheck[T]) cleanEvaluations() { func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { f.enforceWidthConditional(a) f.enforceWidthConditional(b) - k, r, c, err := f.callMulHint(a, b) + k, r, c, err := f.callMulHint(a, b, true) if err != nil { panic(err) } @@ -127,14 +127,15 @@ func (f *Field[T]) checkZero(a *Element[T]) { // the method works similarly to mulMod, but we know that we are multiplying // by one and expected result should be zero. f.enforceWidthConditional(a) - k, r, c, err := f.callCheckZeroHint(a) + b := f.shortOne() + k, r, c, err := f.callMulHint(a, b, false) if err != nil { panic(err) } mc := mulCheck[T]{ f: f, a: a, - b: f.shortOne(), // one on single limb to speed up the polynomial evaluation + b: b, // one on single limb to speed up the polynomial evaluation c: c, k: k, r: r, // expected to be zero on zero limbs. @@ -233,7 +234,7 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // callMulHint uses hint to compute r, k and c. -func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], err error) { +func (f *Field[T]) callMulHint(a, b *Element[T], packRem bool) (quo, rem, carries *Element[T], err error) { // inputs is always nblimbs // quotient may be larger if inputs have overflow // remainder is always nblimbs @@ -242,15 +243,17 @@ func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], // skip error handle - it happens when we are supposed to reduce. But we // already check it as a precondition. We only need the overflow here. nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() - nbQuoLimbs := ((2*nbLimbs-1)*nbBits + nextOverflow + 1 - // + nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // uint(f.fParams.Modulus().BitLen()) + // nbBits - 1) / nbBits nbRemLimbs := nbLimbs - nbCarryLimbs := (nbQuoLimbs + nbLimbs) - 2 + nbCarryLimbs := max(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)), nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs))) - 1 hintInputs := []frontend.Variable{ nbBits, nbLimbs, + len(a.Limbs), + nbQuoLimbs, } hintInputs = append(hintInputs, f.Modulus().Limbs...) hintInputs = append(hintInputs, a.Limbs...) @@ -261,70 +264,29 @@ func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], return } quo = f.packLimbs(ret[:nbQuoLimbs], false) - rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) - carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) - return -} - -func (f *Field[T]) callCheckZeroHint(a *Element[T]) (quo, rem, carries *Element[T], err error) { - nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() - // this is the same as callMulHint, but we know that we multiply a by one. - // This allows to optimize some bounds - namely we know that the overflow of - // the product would be the same as the overflow of a as we would be - // multiplying every limb of a by one. Secondly, the number of limbs of the - // result is also the same as for a. - nextOverflow := a.overflow - nbActualQuoLimbs := (uint(len(a.Limbs))*nbBits + nextOverflow + 1 - // - uint(f.fParams.Modulus().BitLen()) + // - nbBits - 1) / - nbBits - - // also compute the number of limbs assuming that we have 1 on full - // number of limbs to keep compatibility with the multiplication hint - // (which assumes that the number of limbs is the same as the input). - // Otherwise, we would have to provide additional inputs to indicate the - // length of the quotient. - nbFullQuoLimbs := ((2*nbLimbs-1)*nbBits + nextOverflow + 1 - // - uint(f.fParams.Modulus().BitLen()) + // - nbBits - 1) / - nbBits - nbRemLimbs := nbLimbs - nbCarryLimbs := (nbFullQuoLimbs + nbLimbs) - 2 - hintInputs := []frontend.Variable{ - nbBits, - nbLimbs, - } - hintInputs = append(hintInputs, f.Modulus().Limbs...) - hintInputs = append(hintInputs, a.Limbs...) - hintInputs = append(hintInputs, f.One().Limbs...) - ret, err := f.api.NewHint(mulHint, int(nbFullQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) - if err != nil { - err = fmt.Errorf("call hint: %w", err) - return + if packRem { + rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) + } else { + rem = &Element[T]{} } - // now, we only pack the number of actual expected non-zero limbs for - // quotient. This makes later in the multiplication check later cheaper as - // we evaluate the polynomial with limbs as coefficients. - quo = f.packLimbs(ret[:nbActualQuoLimbs], false) - // and the remainder is supposed to be 0. As a polynomial this means it will - // be evaluated to zero. - rem = &Element[T]{} - carries = f.newInternalElement(ret[nbFullQuoLimbs+nbRemLimbs:], 0) + carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) return } func mulHint(field *big.Int, inputs, outputs []*big.Int) error { nbBits := int(inputs[0].Int64()) nbLimbs := int(inputs[1].Int64()) - ptr := 2 + nbALen := int(inputs[2].Int64()) + nbQuoLen := int(inputs[3].Int64()) + nbBLen := len(inputs) - 4 - nbLimbs - nbALen + ptr := 4 plimbs := inputs[ptr : ptr+nbLimbs] ptr += nbLimbs - alimbs := inputs[ptr : ptr+nbLimbs] - ptr += nbLimbs - blimbs := inputs[ptr : ptr+nbLimbs] + alimbs := inputs[ptr : ptr+nbALen] + ptr += nbALen + blimbs := inputs[ptr : ptr+nbBLen] - nbQuoLen := (len(outputs) - 2*nbLimbs + 2) / 2 - nbCarryLen := nbLimbs + nbQuoLen - 2 + nbCarryLen := max(nbMultiplicationResLimbs(nbALen, nbBLen), nbMultiplicationResLimbs(nbQuoLen, nbLimbs)) - 1 outptr := 0 quoLimbs := outputs[outptr : outptr+nbQuoLen] outptr += nbQuoLen @@ -354,8 +316,8 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { if err := decompose(rem, uint(nbBits), remLimbs); err != nil { return fmt.Errorf("decompose rem: %w", err) } - xp := make([]*big.Int, nbLimbs+nbQuoLen-1) - yp := make([]*big.Int, nbLimbs+nbQuoLen-1) + xp := make([]*big.Int, nbMultiplicationResLimbs(nbALen, nbBLen)) + yp := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLen, nbLimbs)) for i := range xp { xp[i] = new(big.Int) } @@ -363,11 +325,13 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { yp[i] = new(big.Int) } tmp := new(big.Int) - for i := 0; i < nbLimbs; i++ { - for j := 0; j < nbLimbs; j++ { + for i := 0; i < nbALen; i++ { + for j := 0; j < nbBLen; j++ { tmp.Mul(alimbs[i], blimbs[j]) xp[i+j].Add(xp[i+j], tmp) } + } + for i := 0; i < nbLimbs; i++ { yp[i].Add(yp[i], remLimbs[i]) for j := 0; j < nbQuoLen; j++ { tmp.Mul(quoLimbs[j], plimbs[i]) @@ -376,8 +340,12 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { } carry := new(big.Int) for i := range carryLimbs { - carry.Add(carry, xp[i]) - carry.Sub(carry, yp[i]) + if i < len(xp) { + carry.Add(carry, xp[i]) + } + if i < len(yp) { + carry.Sub(carry, yp[i]) + } carry.Rsh(carry, uint(nbBits)) carryLimbs[i] = new(big.Int).Set(carry) } From 41b1a89ff70b9019b4f82cc31af893faff78c892 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 28 Feb 2024 18:43:40 +0000 Subject: [PATCH 13/15] fix: mul pre cond overflow computation --- std/math/emulated/field_mul.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 963cc45eb6..70b175010f 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -414,7 +414,7 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { nbResLimbs := nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)) nbLimbsOverflow := uint(1) if nbResLimbs > 0 { - nbLimbsOverflow = uint(bits.Len(uint(2*nbResLimbs - 1))) + nbLimbsOverflow = uint(bits.Len(uint(nbResLimbs))) } nextOverflow = f.fParams.BitsPerLimb() + nbLimbsOverflow + a.overflow + b.overflow if nextOverflow > f.maxOverflow() { From d0e8354ca18b8c23b8a453fe4d9fe1e58bf50c51 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Wed, 28 Feb 2024 22:43:12 +0000 Subject: [PATCH 14/15] docs: comments --- std/math/emulated/field_mul.go | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 70b175010f..3b3235e5cb 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -234,21 +234,33 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // callMulHint uses hint to compute r, k and c. -func (f *Field[T]) callMulHint(a, b *Element[T], packRem bool) (quo, rem, carries *Element[T], err error) { - // inputs is always nblimbs - // quotient may be larger if inputs have overflow - // remainder is always nblimbs - // carries is 2 * nblimbs - 2 (do not consider first limb) +func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carries *Element[T], err error) { + // compute the expected overflow after the multiplication of a*b to be able + // to estimate the number of bits required to represent the result. nextOverflow, _ := f.mulPreCond(a, b) // skip error handle - it happens when we are supposed to reduce. But we // already check it as a precondition. We only need the overflow here. + if !isMulMod { + // b is one on single limb. We do not increase the overflow + nextOverflow = a.overflow + } nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() + // we need to compute the number of limbs for the quotient. To compute it, + // we compute the width of the product of a*b, then we divide it by the + // width of the modulus. We add 1 to the result to ensure that we have + // enough space for the quotient. nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // uint(f.fParams.Modulus().BitLen()) + // nbBits - 1) / nbBits + // the remainder is always less than modulus so can represent on the same + // number of limbs as the modulus. nbRemLimbs := nbLimbs + // we need to compute the number of limbs for the carries. It is maximum of + // the number of limbs of the product of a*b or k*p. nbCarryLimbs := max(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)), nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs))) - 1 + // we encode the computed parameters and widths to the hint function so can + // know how many limbs to expect. hintInputs := []frontend.Variable{ nbBits, nbLimbs, @@ -263,12 +275,19 @@ func (f *Field[T]) callMulHint(a, b *Element[T], packRem bool) (quo, rem, carrie err = fmt.Errorf("call hint: %w", err) return } + // quotient is always range checked according to how many limbs we expect. quo = f.packLimbs(ret[:nbQuoLimbs], false) - if packRem { + // remainder is always range checked when we use it as a result of + // multiplication (and it needs to be strictly less than modulus). However, + // when we use the hint for equality assertion then we assume the result to + // be 0 which can be represented by 0 limbs. + if isMulMod { rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) } else { rem = &Element[T]{} } + // pack the carries into element. Used in the deferred multiplication check + // to align the limbs due to different overflows. carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) return } @@ -325,6 +344,8 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { yp[i] = new(big.Int) } tmp := new(big.Int) + // we know compute the schoolbook multiprecision multiplication of a*b and + // r+k*p for i := 0; i < nbALen; i++ { for j := 0; j < nbBLen; j++ { tmp.Mul(alimbs[i], blimbs[j]) From f679a17b013102172d2686a061c4c662adbf4255 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 5 Mar 2024 12:31:38 +0000 Subject: [PATCH 15/15] chore: stats --- internal/stats/latest.stats | Bin 3067 -> 3067 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index c813f57bd38adefd83563f8c2a8c3d34214f9871..bd072702861d0b30ae344ccb1d5f86f34d6c5b58 100644 GIT binary patch delta 475 zcmew@{#$%P=HypQEt7RQS|__P#Z7+7E^I7j!T8r>+Xn_H;9&f_F831S-^F)2a0*S% zVSc<>jrAPkG4VBX69ywEF4Jsl74XkjY_zK5cYZ(6;{IWrW&EHasG{(QVA7k(c0RWO; BiWC3< delta 477 zcmew@{#$%P=Hy?j4>!kfaWGD`Jd znMIj->G4T9#fC=47KSVw3{b$q_*bX6hVifG@h~Kz^_|}s|1S0C!YMTIb^GLc&byP3 zv)rA0n@zJmjqx9U90TJ&4l)=F9E|^z-5CEVRR9ejhrz&MN1jHQuE{I7Aniazj^b@@Cfn2L$=(^+zd5blLZbj{$0E<4Wa!HV};T$#y^Y+yN_T9@*s-> E0L&ed9smFU