Skip to content

Commit

Permalink
perf: non-native modular multiplication (#749)
Browse files Browse the repository at this point in the history
* feat: use squaring instead of mimcs for commitment expansion

* feat: cache finite field APIs

* feat: add mulmod by poly evaluation

* refactor: move multiplication

* feat: use mulmod for reduction

* fix: clean evaluations after performing mulchecks

* fix: constant strict width check

* perf: update stats

* docs: update package documentation
  • Loading branch information
ivokub authored Dec 4, 2023
1 parent 808a8f4 commit 64c88cb
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 216 deletions.
Binary file modified internal/stats/latest.stats
Binary file not shown.
47 changes: 12 additions & 35 deletions std/math/emulated/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,52 +86,29 @@ then the overflow value f' for the sum is computed as
The complexity of native limb-wise multiplication is k^2. This translates
directly to the complexity in the number of constraints in the constraint
system. However, alternatively, when instead computing the limb values
off-circuit and constructing a system of k linear equations, we can ensure that
the product was computed correctly.
system.
Let the factors be
For multiplication, we would instead use polynomial representation of the elements:
x = ∑_{i=0}^k x_i 2^{w i}
and
y = ∑_{i=0}^k y_i 2^{w i}.
For computing the product, we compute off-circuit the limbs
z_i = ∑_{j, j'>0, j+j'=i, j+j'≤2k-2} x_{j} y_{j'}, // in MultiplicationHint()
and assert in-circuit
∑_{i=0}^{2k-2} z_i c^i = (∑_{i=0}^k x_i) (∑_{i=0}^k y_i), ∀ c ∈ {1, ..., 2k-1}.
Computing the overflow for the multiplication result is slightly more
complicated. The overflow for
x_{j} y_{j'}
is
w+f+f'+1.
Naively, as the limbs of the result are summed over all 0 ≤ i ≤ 2k-2, then the
overflow of the limbs should be
w+f+f'+2k-1.
as
For computing the number of bits and thus in the overflow, we can instead look
at the maximal possible value. This can be computed by
x(X) = ∑_{i=0}^k x_i X^i
y(X) = ∑_{i=0}^k y_i X^i.
(2^{2w+f+f'+2}-1)*(2k-1).
If the multiplication result modulo r is c, then the following holds:
Its bitlength is
x * y = c + z*r.
2w+f+f'+1+log_2(2k-1),
We can check the correctness of the multiplication by checking the following
identity at a random point:
which leads to maximal overflow of
x(X) * y(X) = c(X) + z(X) * r(X) + (2^w' - X) e(X),
w+f+f'+1+log_2(2k-1).
where e(X) is a polynomial used for carrying the overflows of the left- and
right-hand side of the above equation.
# Subtraction
Expand Down
3 changes: 3 additions & 0 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type Element[T FieldParams] struct {
// ensure that the limbs are width-constrained. We do not store the
// enforcement info in the Element to prevent modifying the witness.
internal bool

isEvaluated bool
evaluation frontend.Variable `gnark:"-"`
}

// ValueOf returns an Element[T] from a constant value.
Expand Down
15 changes: 15 additions & 0 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/kvstore"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/logger"
"github.com/consensys/gnark/std/rangecheck"
Expand Down Expand Up @@ -42,8 +43,12 @@ type Field[T FieldParams] struct {

constrainedLimbs map[uint64]struct{}
checker frontend.Rangechecker

mulChecks []mulCheck[T]
}

type ctxKey[T FieldParams] struct{}

// NewField returns an object to be used in-circuit to perform emulated
// arithmetic over the field defined by type parameter [FieldParams]. The
// operations on this type are defined on [Element]. There is also another type
Expand All @@ -53,6 +58,12 @@ type Field[T FieldParams] struct {
// This is an experimental feature and performing emulated arithmetic in-circuit
// is extremly costly. See package doc for more info.
func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
if storer, ok := native.(kvstore.Store); ok {
ff := storer.GetKeyValue(ctxKey[T]{})
if ff, ok := ff.(*Field[T]); ok {
return ff, nil
}
}
f := &Field[T]{
api: native,
log: logger.Logger(),
Expand Down Expand Up @@ -89,6 +100,10 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb())
}

native.Compiler().Defer(f.performMulChecks)
if storer, ok := native.(kvstore.Store); ok {
storer.SetKeyValue(ctxKey[T]{}, f)
}
return f, nil
}

Expand Down
2 changes: 1 addition & 1 deletion std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) {
// (defined by the field parameter).
func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) {
if _, aConst := f.constantValue(a); aConst {
if len(a.Limbs) != int(f.fParams.NbLimbs()) {
if modWidth && len(a.Limbs) != int(f.fParams.NbLimbs()) {
panic("constant limb width doesn't match parametrized field")
}
}
Expand Down
Loading

0 comments on commit 64c88cb

Please sign in to comment.