From b85d0e0ca0628fbd9fabd30cd0dae72df29dcf60 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 2 Jul 2023 15:25:56 -0500 Subject: [PATCH] perf: async parallel plonk pr read (#748) --- backend/plonk/bls12-377/marshal.go | 80 +++++++++--- backend/plonk/bls12-377/setup.go | 47 +++++-- backend/plonk/bls12-381/marshal.go | 80 +++++++++--- backend/plonk/bls12-381/setup.go | 47 +++++-- backend/plonk/bls24-315/marshal.go | 80 +++++++++--- backend/plonk/bls24-315/setup.go | 47 +++++-- backend/plonk/bls24-317/marshal.go | 80 +++++++++--- backend/plonk/bls24-317/setup.go | 47 +++++-- backend/plonk/bn254/marshal.go | 80 +++++++++--- backend/plonk/bn254/setup.go | 47 +++++-- backend/plonk/bw6-633/marshal.go | 80 +++++++++--- backend/plonk/bw6-633/setup.go | 47 +++++-- backend/plonk/bw6-761/marshal.go | 80 +++++++++--- backend/plonk/bw6-761/setup.go | 47 +++++-- go.mod | 4 +- go.sum | 4 +- .../zkpschemes/plonk/plonk.marshal.go.tmpl | 81 +++++++++--- .../zkpschemes/plonk/plonk.setup.go.tmpl | 48 +++++-- internal/tinyfield/vector.go | 118 ++++++++++++++++++ internal/tinyfield/vector_test.go | 22 +++- 20 files changed, 952 insertions(+), 214 deletions(-) diff --git a/backend/plonk/bls12-377/marshal.go b/backend/plonk/bls12-377/marshal.go index baf32ebcf4..3d5eebd1d5 100644 --- a/backend/plonk/bls12-377/marshal.go +++ b/backend/plonk/bls12-377/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bls12-377/setup.go b/backend/plonk/bls12-377/setup.go index 273fbd2239..9416161b6d 100644 --- a/backend/plonk/bls12-377/setup.go +++ b/backend/plonk/bls12-377/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bls12-377" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
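The setup.go hunk that follows (and its verbatim per-curve and template copies later in this patch) replaces seven sequential Clone().ToLagrangeCoset() calls, plus the Qcp loop, with one goroutine per polynomial joined by a sync.WaitGroup; each goroutine writes to a distinct ProvingKey field or slice slot, so no locking is needed. A minimal self-contained sketch of that fan-out pattern — poly and toLagrangeCoset are hypothetical stand-ins, not the gnark API:

package main

import (
	"fmt"
	"sync"
)

// poly stands in for *iop.Polynomial (hypothetical).
type poly []uint64

// toLagrangeCoset stands in for Clone(n).ToLagrangeCoset(&domain).
func toLagrangeCoset(p poly) poly {
	out := make(poly, len(p))
	copy(out, p) // a real implementation would change basis here
	return out
}

func main() {
	srcs := []poly{{1, 2}, {3, 4}, {5, 6}}
	dsts := make([]poly, len(srcs))

	var wg sync.WaitGroup
	wg.Add(len(srcs))
	for i, p := range srcs {
		// Pass i and p as arguments so each goroutine works on its own copies
		// (the same idiom the hunk below uses for the Qcp loop).
		go func(i int, p poly) {
			defer wg.Done()
			dsts[i] = toLagrangeCoset(p) // distinct slot: no mutex required
		}(i, p)
	}
	wg.Wait() // all conversions finished before dsts is used
	fmt.Println(dsts)
}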
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/backend/plonk/bls12-381/marshal.go b/backend/plonk/bls12-381/marshal.go index 7472b7b16b..cd0b5204a5 100644 --- a/backend/plonk/bls12-381/marshal.go +++ b/backend/plonk/bls12-381/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bls12-381/setup.go b/backend/plonk/bls12-381/setup.go index bc5a5efe22..46b2820ce2 100644 --- a/backend/plonk/bls12-381/setup.go +++ b/backend/plonk/bls12-381/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bls12-381" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/backend/plonk/bls24-315/marshal.go b/backend/plonk/bls24-315/marshal.go index 607c72c93e..ad1f16dcfe 100644 --- a/backend/plonk/bls24-315/marshal.go +++ b/backend/plonk/bls24-315/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bls24-315/setup.go b/backend/plonk/bls24-315/setup.go index f7d674ed66..16702c5ee1 100644 --- a/backend/plonk/bls24-315/setup.go +++ b/backend/plonk/bls24-315/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bls24-315" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/backend/plonk/bls24-317/marshal.go b/backend/plonk/bls24-317/marshal.go index 8b1436a444..f3f8280898 100644 --- a/backend/plonk/bls24-317/marshal.go +++ b/backend/plonk/bls24-317/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bls24-317/setup.go b/backend/plonk/bls24-317/setup.go index 77b2e5ac05..60cc0a3020 100644 --- a/backend/plonk/bls24-317/setup.go +++ b/backend/plonk/bls24-317/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bls24-317" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/backend/plonk/bn254/marshal.go b/backend/plonk/bn254/marshal.go index 38ad70f3a6..5d690ee24d 100644 --- a/backend/plonk/bn254/marshal.go +++ b/backend/plonk/bn254/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bn254/setup.go b/backend/plonk/bn254/setup.go index 66b387dc80..20f7c9a64a 100644 --- a/backend/plonk/bn254/setup.go +++ b/backend/plonk/bn254/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bn254" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/backend/plonk/bw6-633/marshal.go b/backend/plonk/bw6-633/marshal.go index cfc681f351..7c6a31e4ea 100644 --- a/backend/plonk/bw6-633/marshal.go +++ b/backend/plonk/bw6-633/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bw6-633/setup.go b/backend/plonk/bw6-633/setup.go index 4d68ed0018..7390c3ee77 100644 --- a/backend/plonk/bw6-633/setup.go +++ b/backend/plonk/bw6-633/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bw6-633" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/backend/plonk/bw6-761/marshal.go b/backend/plonk/bw6-761/marshal.go index 4a9406de50..6fff182303 100644 --- a/backend/plonk/bw6-761/marshal.go +++ b/backend/plonk/bw6-761/marshal.go @@ -189,13 +189,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -217,23 +217,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. + // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) 
for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -254,6 +296,10 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/backend/plonk/bw6-761/setup.go b/backend/plonk/bw6-761/setup.go index 75751f7929..16c2ca2da0 100644 --- a/backend/plonk/bw6-761/setup.go +++ b/backend/plonk/bw6-761/setup.go @@ -26,6 +26,7 @@ import ( "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" cs "github.com/consensys/gnark/constraint/bw6-761" + "sync" ) // Trace stores a plonk trace as columns @@ -176,18 +177,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -207,6 +234,8 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). 
ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } // NbPublicWitness returns the expected public witness size (number of field elements) diff --git a/go.mod b/go.mod index 1933f6b587..3de47678a5 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/bits-and-blooms/bitset v1.7.0 github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.13 - github.com/consensys/gnark-crypto v0.11.1-0.20230701141209-4dc5ff1b675c + github.com/consensys/gnark-crypto v0.11.1-0.20230702195904-e0bc87ecc0e7 github.com/fxamacker/cbor/v2 v2.4.0 github.com/google/go-cmp v0.5.9 github.com/google/pprof v0.0.0-20230309165930-d61513b1440d @@ -15,7 +15,6 @@ require ( github.com/stretchr/testify v1.8.2 golang.org/x/crypto v0.10.0 golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb - golang.org/x/sys v0.9.0 ) require ( @@ -26,6 +25,7 @@ require ( github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect + golang.org/x/sys v0.9.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect diff --git a/go.sum b/go.sum index b7677d20f6..53968ce087 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark-crypto v0.11.1-0.20230701141209-4dc5ff1b675c h1:5khtC1xHGIkFsvEFObZcHnghdlBrsSPYdymtp0iV0do= -github.com/consensys/gnark-crypto v0.11.1-0.20230701141209-4dc5ff1b675c/go.mod h1:6C2ytC8zmP8uH2GKVfPOjf0Vw3KwMAaUxlCPK5WQqmw= +github.com/consensys/gnark-crypto v0.11.1-0.20230702195904-e0bc87ecc0e7 h1:Y4eVT+d64VzJx+9osW/lLpdSzTWnahMMrCSKFj0zO6M= +github.com/consensys/gnark-crypto v0.11.1-0.20230702195904-e0bc87ecc0e7/go.mod h1:6C2ytC8zmP8uH2GKVfPOjf0Vw3KwMAaUxlCPK5WQqmw= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl index ee63bacccb..6639c4434e 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl @@ -169,13 +169,13 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err := pk.Domain[0].ReadFrom(r) + n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.Domain[1].ReadFrom(r) + n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) n += n2 if err != nil { return n, err @@ -197,23 +197,65 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err var ql, qr, qm, qo, qk, lqk, s1, s2, s3 []fr.Element var qcp [][]fr.Element - toDecode := []interface{}{ - &ql, - &qr, - &qm, - &qo, - &qk, - &qcp, - &lqk, - &s1, - &s2, - &s3, - &pk.trace.S, + + // TODO @gbotrel: this is a bit ugly, we should probably refactor this. 
+ // The order of the variables is important, as it matches the order in which they are + // encoded in the WriteTo(...) method. + + // Note: instead of calling dec.Decode(...) for each of the above variables, + // we call AsyncReadFrom when possible which allows to consume bytes from the reader + // and perform the decoding in parallel + + type v struct { + data *fr.Vector + chErr chan error } - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return n + dec.BytesRead(), err + vectors := make([]v, 9) + vectors[0] = v{data: (*fr.Vector)(&ql)} + vectors[1] = v{data: (*fr.Vector)(&qr)} + vectors[2] = v{data: (*fr.Vector)(&qm)} + vectors[3] = v{data: (*fr.Vector)(&qo)} + vectors[4] = v{data: (*fr.Vector)(&qk)} + vectors[5] = v{data: (*fr.Vector)(&lqk)} + vectors[6] = v{data: (*fr.Vector)(&s1)} + vectors[7] = v{data: (*fr.Vector)(&s2)} + vectors[8] = v{data: (*fr.Vector)(&s3)} + + // read ql, qr, qm, qo, qk + for i := 0; i < 5; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read qcp + if err := dec.Decode(&qcp); err != nil { + return n + dec.BytesRead(), err + } + + // read lqk, s1, s2, s3 + for i := 5; i < 9; i++ { + n2, err, ch := vectors[i].data.AsyncReadFrom(r) + n += n2 + if err != nil { + return n, err + } + vectors[i].chErr = ch + } + + // read pk.Trace.S + if err := dec.Decode(&pk.trace.S); err != nil { + return n + dec.BytesRead(), err + } + + // wait for all AsyncReadFrom(...) to complete + for i := range vectors { + if err := <-vectors[i].chErr; err != nil { + return n, err } } @@ -234,6 +276,11 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} pk.lQk = iop.NewPolynomial(&lqk, lagReg) + + // wait for FFT to be precomputed + <-chDomain0 + <-chDomain1 + pk.computeLagrangeCosetPolys() return n + dec.BytesRead(), nil diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index 0bacab893c..d2f6a0ed60 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" + "sync" ) // Trace stores a plonk trace as columns @@ -158,18 +159,44 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // computeLagrangeCosetPolys computes each polynomial except qk in Lagrange coset // basis. Qk will be evaluated in Lagrange coset basis once it is completed by the prover. 
func (pk *ProvingKey) computeLagrangeCosetPolys() { + var wg sync.WaitGroup + wg.Add(7 + len(pk.trace.Qcp)) + n1 := int(pk.Domain[1].Cardinality) pk.lcQcp = make([]*iop.Polynomial, len(pk.trace.Qcp)) for i, qcpI := range pk.trace.Qcp { - pk.lcQcp[i] = qcpI.Clone().ToLagrangeCoset(&pk.Domain[1]) + go func(i int, qcpI *iop.Polynomial) { + pk.lcQcp[i] = qcpI.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }(i, qcpI) } - pk.lcQl = pk.trace.Ql.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQr = pk.trace.Qr.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQm = pk.trace.Qm.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcQo = pk.trace.Qo.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS1 = pk.trace.S1.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS2 = pk.trace.S2.Clone().ToLagrangeCoset(&pk.Domain[1]) - pk.lcS3 = pk.trace.S3.Clone().ToLagrangeCoset(&pk.Domain[1]) - + go func() { + pk.lcQl = pk.trace.Ql.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQr = pk.trace.Qr.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQm = pk.trace.Qm.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcQo = pk.trace.Qo.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS1 = pk.trace.S1.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS2 = pk.trace.S2.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() + go func() { + pk.lcS3 = pk.trace.S3.Clone(n1).ToLagrangeCoset(&pk.Domain[1]) + wg.Done() + }() // storing Id lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} id := make([]fr.Element, pk.Domain[1].Cardinality) @@ -189,8 +216,11 @@ func (pk *ProvingKey) computeLagrangeCosetPolys() { pk.lLoneIOP = iop.NewPolynomial(&lone, lagReg).ToCanonical(&pk.Domain[0]). ToRegular(). ToLagrangeCoset(&pk.Domain[1]) + + wg.Wait() } + // NbPublicWitness returns the expected public witness size (number of field elements) func (vk *VerifyingKey) NbPublicWitness() int { return int(vk.NbPublicVariables) diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go index 4f9e9a15e4..9ef47d3cda 100644 --- a/internal/tinyfield/vector.go +++ b/internal/tinyfield/vector.go @@ -19,8 +19,13 @@ package tinyfield import ( "bytes" "encoding/binary" + "fmt" "io" + "runtime" "strings" + "sync" + "sync/atomic" + "unsafe" ) // Vector represents a slice of Element. @@ -73,6 +78,66 @@ func (vector *Vector) WriteTo(w io.Writer) (int64, error) { return n, nil } +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. +// The validation consist of checking that the elements are smaller than the modulus, and +// converting them to montgomery form. 
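+//
+// A typical caller (hypothetical sketch, not part of this patch) overlaps
+// further reads from r with the background validation:
+//
+//	var vec Vector
+//	n, err, chErr := vec.AsyncReadFrom(r)
+//	if err != nil {
+//		return n, err // the synchronous read itself failed
+//	}
+//	// ... decode other fields from r while validation runs ...
+//	if err := <-chErr; err != nil {
+//		return n, err // some element was not smaller than the modulus
+//	}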
+func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + var z Element + for i := start; i < end; i++ { + // we have to set vector[i] + bstart := i * Bytes + bend := bstart + Bytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint64(b[0:8]) + + if !z.smallerThanModulus() { + atomic.AddUint64(&cptErrors, 1) + return + } + z.toMont() + (*vector)[i] = z + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + // ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. // Length of the vector must be encoded as a uint32 on the first 4 bytes. func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { @@ -130,3 +195,56 @@ func (vector Vector) Less(i, j int) bool { func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. +// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go index e1db416306..68a98e5fa9 100644 --- a/internal/tinyfield/vector_test.go +++ b/internal/tinyfield/vector_test.go @@ -17,6 +17,7 @@ package tinyfield import ( + "bytes" "github.com/stretchr/testify/require" "reflect" "sort" @@ -47,12 +48,16 @@ func TestVectorRoundTrip(t *testing.T) { b, err := v1.MarshalBinary() assert.NoError(err) - var v2 Vector + var v2, v3 Vector err = v2.UnmarshalBinary(b) assert.NoError(err) + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) } func TestVectorEmptyRoundTrip(t *testing.T) { @@ -63,10 +68,23 @@ func TestVectorEmptyRoundTrip(t *testing.T) { b, err := v1.MarshalBinary() assert.NoError(err) - var v2 Vector + var v2, v3 Vector 
err = v2.UnmarshalBinary(b) assert.NoError(err) + err = v3.unmarshalBinaryAsync(b) + assert.NoError(err) + assert.True(reflect.DeepEqual(v1, v2)) + assert.True(reflect.DeepEqual(v3, v2)) +} + +func (vector *Vector) unmarshalBinaryAsync(data []byte) error { + r := bytes.NewReader(data) + _, err, chErr := vector.AsyncReadFrom(r) + if err != nil { + return err + } + return <-chErr }
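The execute helper duplicated into internal/tinyfield/vector.go above is a generic chunked fan-out: it splits [0, nbIterations) into roughly equal, disjoint [start, end) ranges — one per CPU, handing out any remainder one extra iteration at a time — runs work on each range in its own goroutine, and returns only after all of them complete. A small usage sketch, assuming execute as defined in this patch is in scope; the summing workload is invented for illustration:

package tinyfield

import "sync/atomic"

// parallelSum adds up data using the execute helper from this patch.
func parallelSum(data []uint64) uint64 {
	var total uint64
	execute(len(data), func(start, end int) {
		var local uint64 // accumulate locally to avoid contention
		for i := start; i < end; i++ {
			local += data[i]
		}
		// Ranges are disjoint, so only the final add needs to be atomic.
		atomic.AddUint64(&total, local)
	})
	return total
}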
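Stepping back, every per-curve readFrom in this patch follows the same schedule: kick off an AsyncReadFrom per flat vector (bytes are still consumed from r strictly in order), interleave the dec.Decode calls for the two fields with nested structure (qcp and pk.trace.S), and only at the end drain every validation channel. A condensed sketch of that schedule against a hypothetical asyncReader interface, not the real gnark types:

package sketch

import "io"

// asyncReader captures the AsyncReadFrom convention introduced by this patch:
// bytes are consumed synchronously; validation finishes on the returned channel.
type asyncReader interface {
	AsyncReadFrom(r io.Reader) (int64, error, chan error)
}

// readAll mirrors the readFrom schedule used in every marshal.go above.
func readAll(r io.Reader, vecs []asyncReader) (n int64, err error) {
	pending := make([]chan error, 0, len(vecs))
	for _, v := range vecs {
		n2, err, ch := v.AsyncReadFrom(r) // returns once its bytes are consumed
		n += n2
		if err != nil {
			return n, err
		}
		pending = append(pending, ch)
	}
	// ... dec.Decode(...) for the nested fields would interleave here ...
	for _, ch := range pending {
		if err := <-ch; err != nil { // wait for background validation
			return n, err
		}
	}
	return n, nil
}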