Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: non-native modular multiplication #749

Merged
merged 9 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified internal/stats/latest.stats
Binary file not shown.
47 changes: 12 additions & 35 deletions std/math/emulated/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,52 +86,29 @@ then the overflow value f' for the sum is computed as

The complexity of native limb-wise multiplication is k^2. This translates
directly to the complexity in the number of constraints in the constraint
system. However, alternatively, when instead computing the limb values
off-circuit and constructing a system of k linear equations, we can ensure that
the product was computed correctly.
system.

Let the factors be
For multiplication, we would instead use polynomial representation of the elements:

x = ∑_{i=0}^k x_i 2^{w i}

and

y = ∑_{i=0}^k y_i 2^{w i}.

For computing the product, we compute off-circuit the limbs

z_i = ∑_{j, j'>0, j+j'=i, j+j'≤2k-2} x_{j} y_{j'}, // in MultiplicationHint()

and assert in-circuit

∑_{i=0}^{2k-2} z_i c^i = (∑_{i=0}^k x_i) (∑_{i=0}^k y_i), ∀ c ∈ {1, ..., 2k-1}.

Computing the overflow for the multiplication result is slightly more
complicated. The overflow for

x_{j} y_{j'}

is

w+f+f'+1.

Naively, as the limbs of the result are summed over all 0 ≤ i ≤ 2k-2, then the
overflow of the limbs should be

w+f+f'+2k-1.
as

For computing the number of bits and thus in the overflow, we can instead look
at the maximal possible value. This can be computed by
x(X) = ∑_{i=0}^k x_i X^i
y(X) = ∑_{i=0}^k y_i X^i.

(2^{2w+f+f'+2}-1)*(2k-1).
If the multiplication result modulo r is c, then the following holds:

Its bitlength is
x * y = c + z*r.

2w+f+f'+1+log_2(2k-1),
We can check the correctness of the multiplication by checking the following
identity at a random point:

which leads to maximal overflow of
x(X) * y(X) = c(X) + z(X) * r(X) + (2^w' - X) e(X),

w+f+f'+1+log_2(2k-1).
where e(X) is a polynomial used for carrying the overflows of the left- and
right-hand side of the above equation.

# Subtraction

Expand Down
3 changes: 3 additions & 0 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type Element[T FieldParams] struct {
// ensure that the limbs are width-constrained. We do not store the
// enforcement info in the Element to prevent modifying the witness.
internal bool

isEvaluated bool
evaluation frontend.Variable `gnark:"-"`
}

// ValueOf returns an Element[T] from a constant value.
Expand Down
15 changes: 15 additions & 0 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/kvstore"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/logger"
"github.com/consensys/gnark/std/rangecheck"
Expand Down Expand Up @@ -42,8 +43,12 @@ type Field[T FieldParams] struct {

constrainedLimbs map[uint64]struct{}
checker frontend.Rangechecker

mulChecks []mulCheck[T]
}

type ctxKey[T FieldParams] struct{}

// NewField returns an object to be used in-circuit to perform emulated
// arithmetic over the field defined by type parameter [FieldParams]. The
// operations on this type are defined on [Element]. There is also another type
Expand All @@ -53,6 +58,12 @@ type Field[T FieldParams] struct {
// This is an experimental feature and performing emulated arithmetic in-circuit
// is extremly costly. See package doc for more info.
func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
if storer, ok := native.(kvstore.Store); ok {
ff := storer.GetKeyValue(ctxKey[T]{})
if ff, ok := ff.(*Field[T]); ok {
return ff, nil
}
}
f := &Field[T]{
api: native,
log: logger.Logger(),
Expand Down Expand Up @@ -89,6 +100,10 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb())
}

native.Compiler().Defer(f.performMulChecks)
if storer, ok := native.(kvstore.Store); ok {
storer.SetKeyValue(ctxKey[T]{}, f)
}
return f, nil
}

Expand Down
2 changes: 1 addition & 1 deletion std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) {
// (defined by the field parameter).
func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) {
if _, aConst := f.constantValue(a); aConst {
if len(a.Limbs) != int(f.fParams.NbLimbs()) {
if modWidth && len(a.Limbs) != int(f.fParams.NbLimbs()) {
panic("constant limb width doesn't match parametrized field")
}
}
Expand Down
Loading