diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index 34a6ea8798..bd07270286 100644 Binary files a/internal/stats/latest.stats and b/internal/stats/latest.stats differ diff --git a/std/math/emulated/doc.go b/std/math/emulated/doc.go index cb04ce4a38..61a6e54288 100644 --- a/std/math/emulated/doc.go +++ b/std/math/emulated/doc.go @@ -88,7 +88,8 @@ The complexity of native limb-wise multiplication is k^2. This translates directly to the complexity in the number of constraints in the constraint system. -For multiplication, we would instead use polynomial representation of the elements: +For multiplication, we would instead use polynomial representation of the +elements: x = ∑_{i=0}^k x_i 2^{w i} y = ∑_{i=0}^k y_i 2^{w i}. @@ -140,68 +141,15 @@ larger than every limb of b. The subtraction is performed as # Equality checking -The package provides two ways to check equality -- limb-wise equality check and -checking equality by value. +Equality checking is performed using modular multiplication. To check that a, b +are equal modulo r, we compute -In the limb-wise equality check we check that the integer values of the elements -x and y are equal. We have to carry the excess using bit decomposition (which -makes the computation fairly inefficient). To reduce the number of bit -decompositions, we instead carry over the excess of the difference of the limbs -instead. As we take the difference, then similarly as computing the padding in -subtraction algorithm, we need to add padding to the limbs before subtracting -limb-wise to avoid underflows. However, the padding in this case is slightly -different -- we do not need the padding to be divisible by the modulus, but -instead need that the limb padding is larger than the limb which is being -subtracted. + diff = b-a, -Lets look at the algorithm itself. We assume that the overflow f of x is larger -than y. If overflow of y is larger, then we can just swap the arguments and -apply the same argumentation. Let +and enforce modular multiplication check using the techniques for modular +multiplication: - maxValue = 1 << (k+f), // padding for limbs - maxValueShift = 1 << f. // carry part of the padding - -For every limb we compute the difference as - - diff_0 = maxValue+x_0-y_0, - diff_i = maxValue+carry_i+x_i-y_i-maxValueShift. - -We check that the normal part of the difference is zero and carry the rest over -to next limb: - - diff_i[0:k] == 0, - carry_{i+1} = diff_i[k:k+f+1] // we also carry over the padding bit. - -Finally, after we have compared all the limbs, we still need to check that the -final carry corresponds to the padding. We add final check: - - carry_k == maxValueShift. - -We can further optimise the limb-wise equality check by first regrouping the -limbs. The idea is to group several limbs so that the result would still fit -into the scalar field. If - - x = ∑_{i=0}^k x_i 2^{w i}, - -then we can instead take w' divisible by w such that - - x = ∑_{i=0}^(k/(w'/w)) x'_i 2^{w' i}, - -where - - x'_j = ∑_{i=0}^(w'/w) x_{j*w'/w+i} 2^{w i}. - -For element value equality check, we check that two elements x and y are equal -modulo r and for that we need to show that r divides x-y. As mentioned in the -subtraction section, we add sufficient padding such that x-y does not underflow -and its integer value is always larger than 0. We use hint function to compute z -such that - - x-y = z*r, - -compute z*r and use limbwise equality checking to show that - - x-y == z*r. + diff * 1 = 0 + k * r. # Bitwidth enforcement diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 1bd7521f3e..c8e5c817b5 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -17,42 +17,11 @@ import ( const testCurve = ecc.BN254 -type AssertLimbEqualityCircuit[T FieldParams] struct { - A, B Element[T] -} - -func (c *AssertLimbEqualityCircuit[T]) Define(api frontend.API) error { - f, err := NewField[T](api) - if err != nil { - return err - } - f.AssertLimbsEquality(&c.A, &c.B) - return nil -} - func testName[T FieldParams]() string { var fp T return fmt.Sprintf("%s/limb=%d", reflect.TypeOf(fp).Name(), fp.BitsPerLimb()) } -func TestAssertLimbEqualityNoOverflow(t *testing.T) { - testAssertLimbEqualityNoOverflow[Goldilocks](t) - testAssertLimbEqualityNoOverflow[Secp256k1Fp](t) - testAssertLimbEqualityNoOverflow[BN254Fp](t) -} - -func testAssertLimbEqualityNoOverflow[T FieldParams](t *testing.T) { - var fp T - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - var circuit, witness AssertLimbEqualityCircuit[T] - val, _ := rand.Int(rand.Reader, fp.Modulus()) - witness.A = ValueOf[T](val) - witness.B = ValueOf[T](val) - assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness)) - }, testName[T]()) -} - // TODO: add also cases which should fail type AssertIsLessEqualThanCircuit[T FieldParams] struct { @@ -184,9 +153,9 @@ func (c *MulNoOverflowCircuit[T]) Define(api frontend.API) error { } func TestMulCircuitNoOverflow(t *testing.T) { - // testMulCircuitNoOverflow[Goldilocks](t) + testMulCircuitNoOverflow[Goldilocks](t) testMulCircuitNoOverflow[Secp256k1Fp](t) - // testMulCircuitNoOverflow[BN254Fp](t) + testMulCircuitNoOverflow[BN254Fp](t) } func testMulCircuitNoOverflow[T FieldParams](t *testing.T) { diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 3480aaab31..6c1f19b04d 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -30,14 +30,16 @@ type Field[T FieldParams] struct { maxOfOnce sync.Once // constants for often used elements n, 0 and 1. Allocated only once - nConstOnce sync.Once - nConst *Element[T] - nprevConstOnce sync.Once - nprevConst *Element[T] - zeroConstOnce sync.Once - zeroConst *Element[T] - oneConstOnce sync.Once - oneConst *Element[T] + nConstOnce sync.Once + nConst *Element[T] + nprevConstOnce sync.Once + nprevConst *Element[T] + zeroConstOnce sync.Once + zeroConst *Element[T] + oneConstOnce sync.Once + oneConst *Element[T] + shortOneConstOnce sync.Once + shortOneConst *Element[T] log zerolog.Logger @@ -146,6 +148,14 @@ func (f *Field[T]) One() *Element[T] { return f.oneConst } +// shortOne returns one as a constant stored in a single limb. +func (f *Field[T]) shortOne() *Element[T] { + f.shortOneConstOnce.Do(func() { + f.shortOneConst = f.newInternalElement([]frontend.Variable{1}, 0) + }) + return f.shortOneConst +} + // Modulus returns the modulus of the emulated ring as a constant. func (f *Field[T]) Modulus() *Element[T] { f.nConstOnce.Do(func() { @@ -248,51 +258,6 @@ func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) { return res, true } -// compact returns parameters which allow for most optimal regrouping of -// limbs. In regrouping the limbs, we encode multiple existing limbs as a linear -// combination in a single new limb. -// compact returns a and b minimal (in number of limbs) representation that fits in the snark field -func (f *Field[T]) compact(a, b *Element[T]) (ac, bc []frontend.Variable, bitsPerLimb uint) { - // omit width reduction as is done in the calling method already - maxOverflow := max(a.overflow, b.overflow) - // subtract one bit as can not potentially use all bits of Fr and one bit as - // grouping may overflow - maxNbBits := uint(f.api.Compiler().FieldBitLen()) - 2 - maxOverflow - groupSize := maxNbBits / f.fParams.BitsPerLimb() - if groupSize == 0 { - // no space for compact - return a.Limbs, b.Limbs, f.fParams.BitsPerLimb() - } - - bitsPerLimb = f.fParams.BitsPerLimb() * groupSize - - ac = f.compactLimbs(a, groupSize, bitsPerLimb) - bc = f.compactLimbs(b, groupSize, bitsPerLimb) - return -} - -// compactLimbs perform the regrouping of limbs between old and new parameters. -func (f *Field[T]) compactLimbs(e *Element[T], groupSize, bitsPerLimb uint) []frontend.Variable { - if f.fParams.BitsPerLimb() == bitsPerLimb { - return e.Limbs - } - nbLimbs := (uint(len(e.Limbs)) + groupSize - 1) / groupSize - r := make([]frontend.Variable, nbLimbs) - coeffs := make([]*big.Int, groupSize) - one := big.NewInt(1) - for i := range coeffs { - coeffs[i] = new(big.Int) - coeffs[i].Lsh(one, f.fParams.BitsPerLimb()*uint(i)) - } - for i := uint(0); i < nbLimbs; i++ { - r[i] = uint(0) - for j := uint(0); j < groupSize && i*groupSize+j < uint(len(e.Limbs)); j++ { - r[i] = f.api.Add(r[i], f.api.Mul(coeffs[j], e.Limbs[i*groupSize+j])) - } - } - return r -} - // maxOverflow returns the maximal possible overflow for the element. If the // overflow of the next operation exceeds the value returned by this method, // then the limbs may overflow the native field. diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 692db8ef2e..a2809e4eb9 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -2,93 +2,10 @@ package emulated import ( "fmt" - "math/big" "github.com/consensys/gnark/frontend" ) -// assertLimbsEqualitySlow is the main routine in the package. It asserts that the -// two slices of limbs represent the same integer value. This is also the most -// costly operation in the package as it does bit decomposition of the limbs. -func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { - - nbLimbs := max(len(l), len(r)) - maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) - maxValueShift := new(big.Int).Lsh(big.NewInt(1), nbCarryBits) - - var carry frontend.Variable = 0 - for i := 0; i < nbLimbs; i++ { - diff := api.Add(maxValue, carry) - if i < len(l) { - diff = api.Add(diff, l[i]) - } - if i < len(r) { - diff = api.Sub(diff, r[i]) - } - if i > 0 { - diff = api.Sub(diff, maxValueShift) - } - - // carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1] - // we know that diff[:nbBits] are 0 bits, but still need to constrain them. - // to do both; we do a "clean" right shift and only need to boolean constrain the carry part - carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1)) - } - api.AssertIsEqual(carry, maxValueShift) -} - -func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable { - // if v is a constant, work with the big int value. - if c, ok := f.api.Compiler().ConstantValue(v); ok { - bits := make([]frontend.Variable, endDigit-startDigit) - for i := 0; i < len(bits); i++ { - bits[i] = c.Bit(i + startDigit) - } - return bits - } - shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v) - if err != nil { - panic(fmt.Sprintf("right shift: %v", err)) - } - f.checker.Check(shifted[0], endDigit-startDigit) - shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit)) - composed := f.api.Mul(shifted[0], shift) - f.api.AssertIsEqual(composed, v) - return shifted[0] -} - -// AssertLimbsEquality asserts that the limbs represent a same integer value. -// This method does not ensure that the values are equal modulo the field order. -// For strict equality, use AssertIsEqual. -func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) { - f.enforceWidthConditional(a) - f.enforceWidthConditional(b) - ba, aConst := f.constantValue(a) - bb, bConst := f.constantValue(b) - if aConst && bConst { - ba.Mod(ba, f.fParams.Modulus()) - bb.Mod(bb, f.fParams.Modulus()) - if ba.Cmp(bb) != 0 { - panic(fmt.Errorf("constant values are different: %s != %s", ba.String(), bb.String())) - } - return - } - - // first, we check if we can compact a and b; they could be using 8 limbs of 32bits - // but with our snark field, we could express them in 2 limbs of 128bits, which would make bit decomposition - // and limbs equality in-circuit (way) cheaper - ca, cb, bitsPerLimb := f.compact(a, b) - - // slow path -- the overflows are different. Need to compare with carries. - // TODO: we previously assumed that one side was "larger" than the other - // side, but I think this assumption is not valid anymore - if a.overflow > b.overflow { - f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) - } else { - f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) - } -} - // enforceWidth enforces the width of the limbs. When modWidth is true, then the // limbs are asserted to be the width of the modulus (highest limb may be less // than full limb width). Otherwise, every limb is assumed to have same width @@ -129,19 +46,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { } diff := f.Sub(b, a) - - // we compute k such that diff / p == k - // so essentially, we say "I know an element k such that k*p == diff" - // hence, diff == 0 mod p - p := f.Modulus() - k, err := f.computeQuoHint(diff) - if err != nil { - panic(fmt.Sprintf("hint error: %v", err)) - } - - kp := f.reduceAndOp(f.mul, f.mulPreCond, k, p) - - f.AssertLimbsEquality(diff, kp) + f.checkZero(diff) } // AssertIsLessOrEqual ensures that e is less or equal than a. For proper @@ -196,11 +101,31 @@ func (f *Field[T]) AssertIsInRange(a *Element[T]) { func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { ca := f.Reduce(a) f.AssertIsInRange(ca) - res := f.api.IsZero(ca.Limbs[0]) + // we use two approaches for checking if the element is exactly zero. The + // first approach is to check that every limb individually is zero. The + // second approach is to check if the sum of all limbs is zero. Usually, we + // cannot use this approach as we could have false positive due to overflow + // in the native field. However, as the widths of the limbs are restricted, + // then we can ensure in most cases that no overflows happen. + + // as ca is already reduced, then every limb overflow is already 0. Only + // every addition adds a bit to the overflow + totalOverflow := len(ca.Limbs) - 1 + if totalOverflow < int(f.maxOverflow()) { + // the sums of limbs would overflow the native field. Use the first + // approach instead. + res := f.api.IsZero(ca.Limbs[0]) + for i := 1; i < len(ca.Limbs); i++ { + res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + } + return res + } + // default case, limbs sum does not overflow the native field + limbSum := ca.Limbs[0] for i := 1; i < len(ca.Limbs); i++ { - res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + limbSum = f.api.Add(limbSum, ca.Limbs[i]) } - return res + return f.api.IsZero(limbSum) } // // Cmp returns: diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 964a4f3058..3b3235e5cb 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -106,7 +106,7 @@ func (mc *mulCheck[T]) cleanEvaluations() { func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { f.enforceWidthConditional(a) f.enforceWidthConditional(b) - k, r, c, err := f.callMulHint(a, b) + k, r, c, err := f.callMulHint(a, b, true) if err != nil { panic(err) } @@ -122,6 +122,27 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { return r } +// checkZero creates multiplication check a * 1 = 0 + k*p. +func (f *Field[T]) checkZero(a *Element[T]) { + // the method works similarly to mulMod, but we know that we are multiplying + // by one and expected result should be zero. + f.enforceWidthConditional(a) + b := f.shortOne() + k, r, c, err := f.callMulHint(a, b, false) + if err != nil { + panic(err) + } + mc := mulCheck[T]{ + f: f, + a: a, + b: b, // one on single limb to speed up the polynomial evaluation + c: c, + k: k, + r: r, // expected to be zero on zero limbs. + } + f.mulChecks = append(f.mulChecks, mc) +} + // evalWithChallenge represents element a as a polynomial a(X) and evaluates at // at[0]. For efficiency, we use already evaluated powers of at[0] given by at. // It stores the evaluation result inside the Element and marks it as evaluated. @@ -133,7 +154,10 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele if len(at) < len(a.Limbs)-1 { panic("evaluation powers less than limbs") } - sum := f.api.Mul(a.Limbs[0], 1) // copy because we use MulAcc + var sum frontend.Variable = 0 + if len(a.Limbs) > 0 { + sum = f.api.Mul(a.Limbs[0], 1) // copy because we use MulAcc + } for i := 1; i < len(a.Limbs); i++ { sum = f.api.MulAcc(sum, a.Limbs[i], at[i-1]) } @@ -171,15 +195,15 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { // for efficiency, we compute all powers of the challenge as slice at. - coefsLen := 0 + coefsLen := int(f.fParams.NbLimbs()) for i := range f.mulChecks { - coefsLen = max(coefsLen, len(f.mulChecks[i].c.Limbs)) + coefsLen = max(coefsLen, len(f.mulChecks[i].a.Limbs), len(f.mulChecks[i].b.Limbs), + len(f.mulChecks[i].c.Limbs), len(f.mulChecks[i].k.Limbs)) } at := make([]frontend.Variable, coefsLen) - var prev frontend.Variable = 1 - for i := range at { - at[i] = api.Mul(prev, commitment) - prev = at[i] + at[0] = commitment + for i := 1; i < len(at); i++ { + at[i] = api.Mul(at[i-1], commitment) } // evaluate all r, k, c for i := range f.mulChecks { @@ -210,24 +234,38 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // callMulHint uses hint to compute r, k and c. -func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], err error) { - // inputs is always nblimbs - // quotient may be larger if inputs have overflow - // remainder is always nblimbs - // carries is 2 * nblimbs - 2 (do not consider first limb) +func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carries *Element[T], err error) { + // compute the expected overflow after the multiplication of a*b to be able + // to estimate the number of bits required to represent the result. nextOverflow, _ := f.mulPreCond(a, b) // skip error handle - it happens when we are supposed to reduce. But we // already check it as a precondition. We only need the overflow here. + if !isMulMod { + // b is one on single limb. We do not increase the overflow + nextOverflow = a.overflow + } nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() - nbQuoLimbs := ((2*nbLimbs-1)*nbBits + nextOverflow + 1 - // + // we need to compute the number of limbs for the quotient. To compute it, + // we compute the width of the product of a*b, then we divide it by the + // width of the modulus. We add 1 to the result to ensure that we have + // enough space for the quotient. + nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // uint(f.fParams.Modulus().BitLen()) + // nbBits - 1) / nbBits + // the remainder is always less than modulus so can represent on the same + // number of limbs as the modulus. nbRemLimbs := nbLimbs - nbCarryLimbs := (nbQuoLimbs + nbLimbs) - 2 + // we need to compute the number of limbs for the carries. It is maximum of + // the number of limbs of the product of a*b or k*p. + nbCarryLimbs := max(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)), nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs))) - 1 + // we encode the computed parameters and widths to the hint function so can + // know how many limbs to expect. hintInputs := []frontend.Variable{ nbBits, nbLimbs, + len(a.Limbs), + nbQuoLimbs, } hintInputs = append(hintInputs, f.Modulus().Limbs...) hintInputs = append(hintInputs, a.Limbs...) @@ -237,8 +275,19 @@ func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], err = fmt.Errorf("call hint: %w", err) return } + // quotient is always range checked according to how many limbs we expect. quo = f.packLimbs(ret[:nbQuoLimbs], false) - rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) + // remainder is always range checked when we use it as a result of + // multiplication (and it needs to be strictly less than modulus). However, + // when we use the hint for equality assertion then we assume the result to + // be 0 which can be represented by 0 limbs. + if isMulMod { + rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) + } else { + rem = &Element[T]{} + } + // pack the carries into element. Used in the deferred multiplication check + // to align the limbs due to different overflows. carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) return } @@ -246,15 +295,17 @@ func (f *Field[T]) callMulHint(a, b *Element[T]) (quo, rem, carries *Element[T], func mulHint(field *big.Int, inputs, outputs []*big.Int) error { nbBits := int(inputs[0].Int64()) nbLimbs := int(inputs[1].Int64()) - ptr := 2 + nbALen := int(inputs[2].Int64()) + nbQuoLen := int(inputs[3].Int64()) + nbBLen := len(inputs) - 4 - nbLimbs - nbALen + ptr := 4 plimbs := inputs[ptr : ptr+nbLimbs] ptr += nbLimbs - alimbs := inputs[ptr : ptr+nbLimbs] - ptr += nbLimbs - blimbs := inputs[ptr : ptr+nbLimbs] + alimbs := inputs[ptr : ptr+nbALen] + ptr += nbALen + blimbs := inputs[ptr : ptr+nbBLen] - nbQuoLen := (len(outputs) - 2*nbLimbs + 2) / 2 - nbCarryLen := nbLimbs + nbQuoLen - 2 + nbCarryLen := max(nbMultiplicationResLimbs(nbALen, nbBLen), nbMultiplicationResLimbs(nbQuoLen, nbLimbs)) - 1 outptr := 0 quoLimbs := outputs[outptr : outptr+nbQuoLen] outptr += nbQuoLen @@ -284,8 +335,8 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { if err := decompose(rem, uint(nbBits), remLimbs); err != nil { return fmt.Errorf("decompose rem: %w", err) } - xp := make([]*big.Int, nbLimbs+nbQuoLen-1) - yp := make([]*big.Int, nbLimbs+nbQuoLen-1) + xp := make([]*big.Int, nbMultiplicationResLimbs(nbALen, nbBLen)) + yp := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLen, nbLimbs)) for i := range xp { xp[i] = new(big.Int) } @@ -293,11 +344,15 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { yp[i] = new(big.Int) } tmp := new(big.Int) - for i := 0; i < nbLimbs; i++ { - for j := 0; j < nbLimbs; j++ { + // we know compute the schoolbook multiprecision multiplication of a*b and + // r+k*p + for i := 0; i < nbALen; i++ { + for j := 0; j < nbBLen; j++ { tmp.Mul(alimbs[i], blimbs[j]) xp[i+j].Add(xp[i+j], tmp) } + } + for i := 0; i < nbLimbs; i++ { yp[i].Add(yp[i], remLimbs[i]) for j := 0; j < nbQuoLen; j++ { tmp.Mul(quoLimbs[j], plimbs[i]) @@ -306,8 +361,12 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { } carry := new(big.Int) for i := range carryLimbs { - carry.Add(carry, xp[i]) - carry.Sub(carry, yp[i]) + if i < len(xp) { + carry.Add(carry, xp[i]) + } + if i < len(yp) { + carry.Sub(carry, yp[i]) + } carry.Rsh(carry, uint(nbBits)) carryLimbs[i] = new(big.Int).Set(carry) } @@ -376,7 +435,7 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { nbResLimbs := nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)) nbLimbsOverflow := uint(1) if nbResLimbs > 0 { - nbLimbsOverflow = uint(bits.Len(uint(2*nbResLimbs - 1))) + nbLimbsOverflow = uint(bits.Len(uint(nbResLimbs))) } nextOverflow = f.fParams.BitsPerLimb() + nbLimbsOverflow + a.overflow + b.overflow if nextOverflow > f.maxOverflow() { @@ -384,47 +443,3 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) { } return } - -func (f *Field[T]) mul(a, b *Element[T], nextOverflow uint) *Element[T] { - // TODO: kept for [AssertIsEqual]. Consider if this can be removed and we - // can use MulMod for equality assertion. - ba, aConst := f.constantValue(a) - bb, bConst := f.constantValue(b) - if aConst && bConst { - ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) - return newConstElement[T](ba) - } - - // mulResult contains the result (out of circuit) of a * b school book multiplication - // len(mulResult) == len(a) + len(b) - 1 - mulResult, err := f.computeMultiplicationHint(a.Limbs, b.Limbs) - if err != nil { - panic(fmt.Sprintf("multiplication hint: %s", err)) - } - - // we computed the result of the mul outside the circuit (mulResult) - // and we want to constrain inside the circuit that this injected value - // actually matches the in-circuit a * b values - // create constraints (\sum_{i=0}^{m-1} a_i c^i) * (\sum_{i=0}^{m-1} b_i - // c^i) = (\sum_{i=0}^{2m-2} z_i c^i) for c \in {1, 2m-1} - w := new(big.Int) - for c := 1; c <= len(mulResult); c++ { - w.SetInt64(1) // c^i - l := f.api.Mul(a.Limbs[0], 1) - r := f.api.Mul(b.Limbs[0], 1) - o := f.api.Mul(mulResult[0], 1) - - for i := 1; i < len(mulResult); i++ { - w.Lsh(w, uint(c)) - if i < len(a.Limbs) { - l = f.api.MulAcc(l, a.Limbs[i], w) - } - if i < len(b.Limbs) { - r = f.api.MulAcc(r, b.Limbs[i], w) - } - o = f.api.MulAcc(o, mulResult[i], w) - } - f.api.AssertIsEqual(f.api.Mul(l, r), o) - } - return f.newInternalElement(mulResult, nextOverflow) -} diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index b2b96e0ace..4115089a8c 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -153,17 +153,6 @@ func (f *Field[T]) Sub(a, b *Element[T]) *Element[T] { return f.reduceAndOp(f.sub, f.subPreCond, a, b) } -// subReduce returns a-b and returns it. Contrary to [Field[T].Sub] method this -// method does not reduce the inputs if the result would overflow. This method -// is currently only used as a subroutine in [Field[T].Reduce] method to avoid -// infinite recursion when we are working exactly on the overflow limits. -func (f *Field[T]) subNoReduce(a, b *Element[T]) *Element[T] { - nextOverflow, _ := f.subPreCond(a, b) - // we ignore error as it only indicates if we should reduce or not. But we - // are in non-reducing version of sub. - return f.sub(a, b, nextOverflow) -} - func (f *Field[T]) subPreCond(a, b *Element[T]) (nextOverflow uint, err error) { reduceRight := a.overflow < (b.overflow + 1) nextOverflow = max(b.overflow+1, a.overflow) + 1 diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 16a560f9c7..6c1644c407 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -19,27 +19,12 @@ func init() { func GetHints() []solver.Hint { return []solver.Hint{ DivHint, - QuoHint, InverseHint, - MultiplicationHint, - RightShift, SqrtHint, mulHint, } } -// computeMultiplicationHint packs the inputs for the MultiplicationHint hint function. -func (f *Field[T]) computeMultiplicationHint(leftLimbs, rightLimbs []frontend.Variable) (mulLimbs []frontend.Variable, err error) { - hintInputs := []frontend.Variable{ - f.fParams.BitsPerLimb(), - len(leftLimbs), - len(rightLimbs), - } - hintInputs = append(hintInputs, leftLimbs...) - hintInputs = append(hintInputs, rightLimbs...) - return f.api.NewHint(MultiplicationHint, nbMultiplicationResLimbs(len(leftLimbs), len(rightLimbs)), hintInputs...) -} - // nbMultiplicationResLimbs returns the number of limbs which fit the // multiplication result. func nbMultiplicationResLimbs(lenLeft, lenRight int) int { @@ -50,87 +35,6 @@ func nbMultiplicationResLimbs(lenLeft, lenRight int) int { return res } -// MultiplicationHint unpacks the factors and parameters from inputs, computes -// the product and stores it in output. See internal method -// computeMultiplicationHint for the input packing. -func MultiplicationHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) < 3 { - return fmt.Errorf("input must be at least three elements") - } - nbBits := int(inputs[0].Int64()) - if 2*nbBits+1 >= mod.BitLen() { - return fmt.Errorf("can not fit multiplication result into limb of length %d", nbBits) - } - // TODO: check that the scalar field fits 2*nbBits + nbLimbs. 2*nbBits comes - // from multiplication and nbLimbs comes from additions. - // TODO: check that all limbs all fully reduced - nbLimbsLeft := int(inputs[1].Int64()) - // TODO: get the limb length from the input instead of packing into input - nbLimbsRight := int(inputs[2].Int64()) - if len(inputs) != 3+nbLimbsLeft+nbLimbsRight { - return fmt.Errorf("input invalid") - } - if len(outputs) < nbLimbsLeft+nbLimbsRight-1 { - return fmt.Errorf("can not fit multiplication result into %d limbs", len(outputs)) - } - for _, oi := range outputs { - if oi == nil { - return fmt.Errorf("output not initialized") - } - oi.SetUint64(0) - } - tmp := new(big.Int) - for i, li := range inputs[3 : 3+nbLimbsLeft] { - for j, rj := range inputs[3+nbLimbsLeft:] { - outputs[i+j].Add(outputs[i+j], tmp.Mul(li, rj)) - } - } - return nil -} - -// computeQuoHint packs the inputs for QuoHint function and returns z = x / y -// (discards remainder) -func (f *Field[T]) computeQuoHint(x *Element[T]) (z *Element[T], err error) { - var fp T - resLen := (uint(len(x.Limbs))*fp.BitsPerLimb() + x.overflow + 1 - // diff total bitlength - uint(fp.Modulus().BitLen()) + // subtract modulus bitlength - fp.BitsPerLimb() - 1) / // to round up - fp.BitsPerLimb() - - hintInputs := []frontend.Variable{ - fp.BitsPerLimb(), - len(x.Limbs), - } - p := f.Modulus() - hintInputs = append(hintInputs, x.Limbs...) - hintInputs = append(hintInputs, p.Limbs...) - - limbs, err := f.api.NewHint(QuoHint, int(resLen), hintInputs...) - if err != nil { - return nil, err - } - - return f.packLimbs(limbs, false), nil -} - -// QuoHint sets z to the quotient x/y for y != 0 and returns z. -// If y == 0, returns an error. -// Quo implements truncated division (like Go); see QuoRem for more details. -func QuoHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { - nbBits, _, x, y, err := parseHintDivInputs(inputs) - if err != nil { - return err - } - z := new(big.Int) - z.Quo(x, y) //.Mod(z, y) - - if err := decompose(z, nbBits, outputs); err != nil { - return fmt.Errorf("decompose: %w", err) - } - - return nil -} - // computeInverseHint packs the inputs for the InverseHint hint function. func (f *Field[T]) computeInverseHint(inLimbs []frontend.Variable) (inverseLimbs []frontend.Variable, err error) { var fp T @@ -232,51 +136,6 @@ func DivHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } -// input[0] = nbBits per limb -// input[1] = nbLimbs(x) -// input[2:2+nbLimbs(x)] = limbs(x) -// input[2+nbLimbs(x):] = limbs(y) -// errors if y == 0 -func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error) { - if len(inputs) < 2 { - return 0, 0, nil, nil, fmt.Errorf("at least 2 inputs required") - } - nbBits := uint(inputs[0].Uint64()) - nbLimbs := int(inputs[1].Int64()) - if len(inputs[2:]) < nbLimbs { - return 0, 0, nil, nil, fmt.Errorf("x limbs missing") - } - x, y := new(big.Int), new(big.Int) - if err := recompose(inputs[2:2+nbLimbs], nbBits, x); err != nil { - return 0, 0, nil, nil, fmt.Errorf("recompose x: %w", err) - } - if err := recompose(inputs[2+nbLimbs:], nbBits, y); err != nil { - return 0, 0, nil, nil, fmt.Errorf("recompose y: %w", err) - } - if y.IsUint64() && y.Uint64() == 0 { - return 0, 0, nil, nil, fmt.Errorf("y == 0") - } - return nbBits, nbLimbs, x, y, nil -} - -// RightShift shifts input by the given number of bits. Expects two inputs: -// - first input is the shift, will be represented as uint64; -// - second input is the value to be shifted. -// -// Returns a single output which is the value shifted. Errors if number of -// inputs is not 2 and number of outputs is not 1. -func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) != 2 { - return fmt.Errorf("expecting two inputs") - } - if len(outputs) != 1 { - return fmt.Errorf("expecting single output") - } - shift := inputs[0].Uint64() - outputs[0].Rsh(inputs[1], uint(shift)) - return nil -} - // SqrtHint compute square root of the input. func SqrtHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return UnwrapHint(inputs, outputs, func(field *big.Int, inputs, outputs []*big.Int) error {