diff --git a/backend/groth16/bls12-377/mpcsetup/marshal_test.go b/backend/groth16/bls12-377/mpcsetup/marshal_test.go index fafeac27ca..ab7e6956cc 100644 --- a/backend/groth16/bls12-377/mpcsetup/marshal_test.go +++ b/backend/groth16/bls12-377/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bls12-377/mpcsetup/phase2.go b/backend/groth16/bls12-377/mpcsetup/phase2.go index 3460138d12..e3816d65ca 100644 --- a/backend/groth16/bls12-377/mpcsetup/phase2.go +++ b/backend/groth16/bls12-377/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bls12-377/mpcsetup/setup_test.go b/backend/groth16/bls12-377/mpcsetup/setup_test.go index 6cfb710888..ca8cca346f 100644 --- a/backend/groth16/bls12-377/mpcsetup/setup_test.go +++ b/backend/groth16/bls12-377/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bls12-377/prove.go b/backend/groth16/bls12-377/prove.go index 7f52932504..ed7124a557 100644 --- a/backend/groth16/bls12-377/prove.go +++ b/backend/groth16/bls12-377/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bls12-377/setup.go b/backend/groth16/bls12-377/setup.go index 1cd302a9a9..393d6a802e 100644 --- a/backend/groth16/bls12-377/setup.go +++ b/backend/groth16/bls12-377/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "math/big" "math/bits" ) diff --git a/backend/groth16/bls12-381/mpcsetup/marshal_test.go b/backend/groth16/bls12-381/mpcsetup/marshal_test.go index 9432d50a6a..bbcaa65d36 100644 --- a/backend/groth16/bls12-381/mpcsetup/marshal_test.go +++ b/backend/groth16/bls12-381/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bls12-381/mpcsetup/phase2.go b/backend/groth16/bls12-381/mpcsetup/phase2.go index f2cd98bd77..ed42a69f9c 100644 --- a/backend/groth16/bls12-381/mpcsetup/phase2.go +++ b/backend/groth16/bls12-381/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bls12-381/mpcsetup/setup_test.go b/backend/groth16/bls12-381/mpcsetup/setup_test.go index e801821bea..0e9880b010 100644 --- a/backend/groth16/bls12-381/mpcsetup/setup_test.go +++ b/backend/groth16/bls12-381/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bls12-381/prove.go b/backend/groth16/bls12-381/prove.go index 23ab3529dc..6e4c0a5227 100644 --- a/backend/groth16/bls12-381/prove.go +++ b/backend/groth16/bls12-381/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bls12-381/setup.go b/backend/groth16/bls12-381/setup.go index d1ca2e9b2d..b5333ba374 100644 --- a/backend/groth16/bls12-381/setup.go +++ b/backend/groth16/bls12-381/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "math/big" "math/bits" ) diff --git a/backend/groth16/bls24-315/mpcsetup/marshal_test.go b/backend/groth16/bls24-315/mpcsetup/marshal_test.go index 9ca32105de..02d61df6a9 100644 --- a/backend/groth16/bls24-315/mpcsetup/marshal_test.go +++ b/backend/groth16/bls24-315/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bls24-315/mpcsetup/phase2.go b/backend/groth16/bls24-315/mpcsetup/phase2.go index b85efd3856..48939131d3 100644 --- a/backend/groth16/bls24-315/mpcsetup/phase2.go +++ b/backend/groth16/bls24-315/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bls24-315/mpcsetup/setup_test.go b/backend/groth16/bls24-315/mpcsetup/setup_test.go index c640467623..25c8affc68 100644 --- a/backend/groth16/bls24-315/mpcsetup/setup_test.go +++ b/backend/groth16/bls24-315/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bls24-315/prove.go b/backend/groth16/bls24-315/prove.go index 0b0774af00..c464544ad0 100644 --- a/backend/groth16/bls24-315/prove.go +++ b/backend/groth16/bls24-315/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bls24-315/setup.go b/backend/groth16/bls24-315/setup.go index a37ce828cd..6a8c8e60d2 100644 --- a/backend/groth16/bls24-315/setup.go +++ b/backend/groth16/bls24-315/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "math/big" "math/bits" ) diff --git a/backend/groth16/bls24-317/mpcsetup/marshal_test.go b/backend/groth16/bls24-317/mpcsetup/marshal_test.go index 6fc34b923c..ddc7229489 100644 --- a/backend/groth16/bls24-317/mpcsetup/marshal_test.go +++ b/backend/groth16/bls24-317/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bls24-317" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bls24-317/mpcsetup/phase2.go b/backend/groth16/bls24-317/mpcsetup/phase2.go index cac7bc08eb..d3037cc3d3 100644 --- a/backend/groth16/bls24-317/mpcsetup/phase2.go +++ b/backend/groth16/bls24-317/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bls24-317/mpcsetup/setup_test.go b/backend/groth16/bls24-317/mpcsetup/setup_test.go index 8332441b27..750ab1e0cf 100644 --- a/backend/groth16/bls24-317/mpcsetup/setup_test.go +++ b/backend/groth16/bls24-317/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bls24-317/prove.go b/backend/groth16/bls24-317/prove.go index 72f08ac910..10f38c5f77 100644 --- a/backend/groth16/bls24-317/prove.go +++ b/backend/groth16/bls24-317/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bls24-317/setup.go b/backend/groth16/bls24-317/setup.go index 04156ae8bd..68ee5c8922 100644 --- a/backend/groth16/bls24-317/setup.go +++ b/backend/groth16/bls24-317/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "math/big" "math/bits" ) diff --git a/backend/groth16/bn254/mpcsetup/marshal_test.go b/backend/groth16/bn254/mpcsetup/marshal_test.go index b56e2bfdce..ce03c11bd4 100644 --- a/backend/groth16/bn254/mpcsetup/marshal_test.go +++ b/backend/groth16/bn254/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bn254/mpcsetup/phase2.go b/backend/groth16/bn254/mpcsetup/phase2.go index c14e14b509..3fcafb30da 100644 --- a/backend/groth16/bn254/mpcsetup/phase2.go +++ b/backend/groth16/bn254/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bn254/mpcsetup/setup_test.go b/backend/groth16/bn254/mpcsetup/setup_test.go index e6644c664d..63b717cac4 100644 --- a/backend/groth16/bn254/mpcsetup/setup_test.go +++ b/backend/groth16/bn254/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bn254/prove.go b/backend/groth16/bn254/prove.go index 75eb17e895..42ec4de8b9 100644 --- a/backend/groth16/bn254/prove.go +++ b/backend/groth16/bn254/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bn254/setup.go b/backend/groth16/bn254/setup.go index d6eaf37793..372c723da0 100644 --- a/backend/groth16/bn254/setup.go +++ b/backend/groth16/bn254/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "math/big" "math/bits" ) diff --git a/backend/groth16/bw6-633/mpcsetup/marshal_test.go b/backend/groth16/bw6-633/mpcsetup/marshal_test.go index 0c04d15934..b9d346f3bd 100644 --- a/backend/groth16/bw6-633/mpcsetup/marshal_test.go +++ b/backend/groth16/bw6-633/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bw6-633/mpcsetup/phase2.go b/backend/groth16/bw6-633/mpcsetup/phase2.go index e966fd021a..cdf0bb7578 100644 --- a/backend/groth16/bw6-633/mpcsetup/phase2.go +++ b/backend/groth16/bw6-633/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bw6-633/mpcsetup/setup_test.go b/backend/groth16/bw6-633/mpcsetup/setup_test.go index 6933f11050..fa51d16fe2 100644 --- a/backend/groth16/bw6-633/mpcsetup/setup_test.go +++ b/backend/groth16/bw6-633/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bw6-633/prove.go b/backend/groth16/bw6-633/prove.go index 41002c3f48..b92dbb6943 100644 --- a/backend/groth16/bw6-633/prove.go +++ b/backend/groth16/bw6-633/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bw6-633/setup.go b/backend/groth16/bw6-633/setup.go index 92c73871b7..f168993476 100644 --- a/backend/groth16/bw6-633/setup.go +++ b/backend/groth16/bw6-633/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "math/big" "math/bits" ) diff --git a/backend/groth16/bw6-761/mpcsetup/marshal_test.go b/backend/groth16/bw6-761/mpcsetup/marshal_test.go index 46913c86b5..cdb362ab70 100644 --- a/backend/groth16/bw6-761/mpcsetup/marshal_test.go +++ b/backend/groth16/bw6-761/mpcsetup/marshal_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( "bytes" curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/stretchr/testify/require" diff --git a/backend/groth16/bw6-761/mpcsetup/phase2.go b/backend/groth16/bw6-761/mpcsetup/phase2.go index df2d2d5be9..cb0c6f9768 100644 --- a/backend/groth16/bw6-761/mpcsetup/phase2.go +++ b/backend/groth16/bw6-761/mpcsetup/phase2.go @@ -24,7 +24,7 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" ) type Phase2Evaluations struct { diff --git a/backend/groth16/bw6-761/mpcsetup/setup_test.go b/backend/groth16/bw6-761/mpcsetup/setup_test.go index 28e4c7a768..83994ca73d 100644 --- a/backend/groth16/bw6-761/mpcsetup/setup_test.go +++ b/backend/groth16/bw6-761/mpcsetup/setup_test.go @@ -19,7 +19,7 @@ package mpcsetup import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "testing" "github.com/consensys/gnark/backend/groth16" diff --git a/backend/groth16/bw6-761/prove.go b/backend/groth16/bw6-761/prove.go index 7234bf7d15..3ee6b9ad0f 100644 --- a/backend/groth16/bw6-761/prove.go +++ b/backend/groth16/bw6-761/prove.go @@ -26,7 +26,7 @@ import ( "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" @@ -94,6 +94,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/backend/groth16/bw6-761/setup.go b/backend/groth16/bw6-761/setup.go index 5a69e9799a..b0fa2811e6 100644 --- a/backend/groth16/bw6-761/setup.go +++ b/backend/groth16/bw6-761/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/pedersen" "github.com/consensys/gnark/backend/groth16/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "math/big" "math/bits" ) diff --git a/backend/plonk/bls12-377/prove.go b/backend/plonk/bls12-377/prove.go index 43e999fd87..3f767215ba 100644 --- a/backend/plonk/bls12-377/prove.go +++ b/backend/plonk/bls12-377/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bls12-377/setup.go b/backend/plonk/bls12-377/setup.go index 20d21a0d29..273fbd2239 100644 --- a/backend/plonk/bls12-377/setup.go +++ b/backend/plonk/bls12-377/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" ) // Trace stores a plonk trace as columns diff --git a/backend/plonk/bls12-381/prove.go b/backend/plonk/bls12-381/prove.go index 4397656393..17627cd176 100644 --- a/backend/plonk/bls12-381/prove.go +++ b/backend/plonk/bls12-381/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bls12-381/setup.go b/backend/plonk/bls12-381/setup.go index b3f00a335b..bc5a5efe22 100644 --- a/backend/plonk/bls12-381/setup.go +++ b/backend/plonk/bls12-381/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" ) // Trace stores a plonk trace as columns diff --git a/backend/plonk/bls24-315/prove.go b/backend/plonk/bls24-315/prove.go index 777bd2b8e1..7ed6145ebb 100644 --- a/backend/plonk/bls24-315/prove.go +++ b/backend/plonk/bls24-315/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bls24-315/setup.go b/backend/plonk/bls24-315/setup.go index 41fd1ee0a3..f7d674ed66 100644 --- a/backend/plonk/bls24-315/setup.go +++ b/backend/plonk/bls24-315/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" ) // Trace stores a plonk trace as columns diff --git a/backend/plonk/bls24-317/prove.go b/backend/plonk/bls24-317/prove.go index ae3d6a28dd..0cdc878d8d 100644 --- a/backend/plonk/bls24-317/prove.go +++ b/backend/plonk/bls24-317/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bls24-317/setup.go b/backend/plonk/bls24-317/setup.go index de56d191cd..77b2e5ac05 100644 --- a/backend/plonk/bls24-317/setup.go +++ b/backend/plonk/bls24-317/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" ) // Trace stores a plonk trace as columns diff --git a/backend/plonk/bn254/prove.go b/backend/plonk/bn254/prove.go index 70baf95687..40b120a21f 100644 --- a/backend/plonk/bn254/prove.go +++ b/backend/plonk/bn254/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bn254/setup.go b/backend/plonk/bn254/setup.go index a4e65bce21..66b387dc80 100644 --- a/backend/plonk/bn254/setup.go +++ b/backend/plonk/bn254/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" ) // Trace stores a plonk trace as columns diff --git a/backend/plonk/bw6-633/prove.go b/backend/plonk/bw6-633/prove.go index f5ffbaba57..1ea6ccf02d 100644 --- a/backend/plonk/bw6-633/prove.go +++ b/backend/plonk/bw6-633/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bw6-633/setup.go b/backend/plonk/bw6-633/setup.go index 9848de6fa7..4d68ed0018 100644 --- a/backend/plonk/bw6-633/setup.go +++ b/backend/plonk/bw6-633/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" ) // Trace stores a plonk trace as columns diff --git a/backend/plonk/bw6-761/prove.go b/backend/plonk/bw6-761/prove.go index 54a179a33b..603a95cf31 100644 --- a/backend/plonk/bw6-761/prove.go +++ b/backend/plonk/bw6-761/prove.go @@ -34,7 +34,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" @@ -127,6 +127,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/backend/plonk/bw6-761/setup.go b/backend/plonk/bw6-761/setup.go index f15643bf7b..75751f7929 100644 --- a/backend/plonk/bw6-761/setup.go +++ b/backend/plonk/bw6-761/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" "github.com/consensys/gnark/backend/plonk/internal" "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" ) // Trace stores a plonk trace as columns diff --git a/backend/plonkfri/bls12-377/prove.go b/backend/plonkfri/bls12-377/prove.go index a525c90cb5..97d6c0301e 100644 --- a/backend/plonkfri/bls12-377/prove.go +++ b/backend/plonkfri/bls12-377/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fri" diff --git a/backend/plonkfri/bls12-377/setup.go b/backend/plonkfri/bls12-377/setup.go index 2168bbeeab..f4e5ad6daf 100644 --- a/backend/plonkfri/bls12-377/setup.go +++ b/backend/plonkfri/bls12-377/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fri" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" ) // ProvingKey stores the data needed to generate a proof: diff --git a/backend/plonkfri/bls12-381/prove.go b/backend/plonkfri/bls12-381/prove.go index 3823ed3b26..d121b59a37 100644 --- a/backend/plonkfri/bls12-381/prove.go +++ b/backend/plonkfri/bls12-381/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fri" diff --git a/backend/plonkfri/bls12-381/setup.go b/backend/plonkfri/bls12-381/setup.go index 4f2b6e1b20..11f574f5f2 100644 --- a/backend/plonkfri/bls12-381/setup.go +++ b/backend/plonkfri/bls12-381/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fri" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" ) // ProvingKey stores the data needed to generate a proof: diff --git a/backend/plonkfri/bls24-315/prove.go b/backend/plonkfri/bls24-315/prove.go index 5d7f385ea1..cb1f43fcee 100644 --- a/backend/plonkfri/bls24-315/prove.go +++ b/backend/plonkfri/bls24-315/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fri" diff --git a/backend/plonkfri/bls24-315/setup.go b/backend/plonkfri/bls24-315/setup.go index 9063dc95d9..c3c0837f90 100644 --- a/backend/plonkfri/bls24-315/setup.go +++ b/backend/plonkfri/bls24-315/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fri" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" ) // ProvingKey stores the data needed to generate a proof: diff --git a/backend/plonkfri/bls24-317/prove.go b/backend/plonkfri/bls24-317/prove.go index 409d66b34e..5fc6cbf713 100644 --- a/backend/plonkfri/bls24-317/prove.go +++ b/backend/plonkfri/bls24-317/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fri" diff --git a/backend/plonkfri/bls24-317/setup.go b/backend/plonkfri/bls24-317/setup.go index c2ab7ee5be..54668b98f8 100644 --- a/backend/plonkfri/bls24-317/setup.go +++ b/backend/plonkfri/bls24-317/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fri" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" ) // ProvingKey stores the data needed to generate a proof: diff --git a/backend/plonkfri/bn254/prove.go b/backend/plonkfri/bn254/prove.go index 1cb0e6fbc6..161ad667f4 100644 --- a/backend/plonkfri/bn254/prove.go +++ b/backend/plonkfri/bn254/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fri" diff --git a/backend/plonkfri/bn254/setup.go b/backend/plonkfri/bn254/setup.go index 6b37f37600..9f69648ed1 100644 --- a/backend/plonkfri/bn254/setup.go +++ b/backend/plonkfri/bn254/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fri" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" ) // ProvingKey stores the data needed to generate a proof: diff --git a/backend/plonkfri/bw6-633/prove.go b/backend/plonkfri/bw6-633/prove.go index 3e9d8e8ef4..e71df6e7aa 100644 --- a/backend/plonkfri/bw6-633/prove.go +++ b/backend/plonkfri/bw6-633/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fri" diff --git a/backend/plonkfri/bw6-633/setup.go b/backend/plonkfri/bw6-633/setup.go index dff84dbee9..1384eed0bd 100644 --- a/backend/plonkfri/bw6-633/setup.go +++ b/backend/plonkfri/bw6-633/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fri" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" ) // ProvingKey stores the data needed to generate a proof: diff --git a/backend/plonkfri/bw6-761/prove.go b/backend/plonkfri/bw6-761/prove.go index f57bc4f6d3..9092580485 100644 --- a/backend/plonkfri/bw6-761/prove.go +++ b/backend/plonkfri/bw6-761/prove.go @@ -28,7 +28,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fri" diff --git a/backend/plonkfri/bw6-761/setup.go b/backend/plonkfri/bw6-761/setup.go index ee278a5eda..a08a528d80 100644 --- a/backend/plonkfri/bw6-761/setup.go +++ b/backend/plonkfri/bw6-761/setup.go @@ -21,7 +21,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fri" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" ) // ProvingKey stores the data needed to generate a proof: diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go new file mode 100644 index 0000000000..31dc279700 --- /dev/null +++ b/constraint/bls12-377/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bls12-377/r1cs_test.go b/constraint/bls12-377/r1cs_test.go index b3eb222420..044c55d31a 100644 --- a/constraint/bls12-377/r1cs_test.go +++ b/constraint/bls12-377/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls12-377" + cs "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) diff --git a/constraint/bls12-377/system.go b/constraint/bls12-377/system.go index 5a35f10a00..49ba1525ef 100644 --- a/constraint/bls12-377/system.go +++ b/constraint/bls12-377/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go new file mode 100644 index 0000000000..5adfe3ec50 --- /dev/null +++ b/constraint/bls12-381/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bls12-381/r1cs_test.go b/constraint/bls12-381/r1cs_test.go index c7b0fa6156..28f77b956b 100644 --- a/constraint/bls12-381/r1cs_test.go +++ b/constraint/bls12-381/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls12-381" + cs "github.com/consensys/gnark/constraint/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" ) diff --git a/constraint/bls12-381/system.go b/constraint/bls12-381/system.go index ff3c074f67..5bcb7d44ee 100644 --- a/constraint/bls12-381/system.go +++ b/constraint/bls12-381/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go new file mode 100644 index 0000000000..048f8fa911 --- /dev/null +++ b/constraint/bls24-315/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bls24-315/r1cs_test.go b/constraint/bls24-315/r1cs_test.go index 3885c438b5..4c42f78ee5 100644 --- a/constraint/bls24-315/r1cs_test.go +++ b/constraint/bls24-315/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls24-315" + cs "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" ) diff --git a/constraint/bls24-315/system.go b/constraint/bls24-315/system.go index de1f206275..01ea069cb3 100644 --- a/constraint/bls24-315/system.go +++ b/constraint/bls24-315/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go new file mode 100644 index 0000000000..f81157a0cf --- /dev/null +++ b/constraint/bls24-317/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bls24-317/r1cs_test.go b/constraint/bls24-317/r1cs_test.go index 9ad8d2f586..40bb573a0b 100644 --- a/constraint/bls24-317/r1cs_test.go +++ b/constraint/bls24-317/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bls24-317" + cs "github.com/consensys/gnark/constraint/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" ) diff --git a/constraint/bls24-317/system.go b/constraint/bls24-317/system.go index f8ac490de2..06633648ce 100644 --- a/constraint/bls24-317/system.go +++ b/constraint/bls24-317/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/blueprint.go b/constraint/blueprint.go index 1471d9d3e9..2949f0665f 100644 --- a/constraint/blueprint.go +++ b/constraint/blueprint.go @@ -2,7 +2,7 @@ package constraint type BlueprintID uint32 -// Blueprint enable representing heterogenous constraints or instructions in a constraint system +// Blueprint enable representing heterogeneous constraints or instructions in a constraint system // in a memory efficient way. Blueprints essentially help the frontend/ to "compress" // constraints or instructions, and specify for the solving (or zksnark setup) part how to // "decompress" and optionally "solve" the associated wires. @@ -66,8 +66,8 @@ type BlueprintHint interface { DecompressHint(h *HintMapping, instruction Instruction) } -// Compressable represent an object that knows how to encode itself as a []uint32. -type Compressable interface { +// Compressible represent an object that knows how to encode itself as a []uint32. +type Compressible interface { // Compress interprets the objects as a LinearExpression and encodes it as a []uint32. Compress(to *[]uint32) } diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go new file mode 100644 index 0000000000..d834f7caca --- /dev/null +++ b/constraint/bn254/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bn254/r1cs_test.go b/constraint/bn254/r1cs_test.go index 8c8ba5da95..d295603ebf 100644 --- a/constraint/bn254/r1cs_test.go +++ b/constraint/bn254/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bn254" + cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) diff --git a/constraint/bn254/system.go b/constraint/bn254/system.go index 949e532dbe..a96fe5d5a2 100644 --- a/constraint/bn254/system.go +++ b/constraint/bn254/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go new file mode 100644 index 0000000000..fd3feedb08 --- /dev/null +++ b/constraint/bw6-633/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bw6-633/r1cs_test.go b/constraint/bw6-633/r1cs_test.go index 9f9d6a0d6f..7111c25c77 100644 --- a/constraint/bw6-633/r1cs_test.go +++ b/constraint/bw6-633/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bw6-633" + cs "github.com/consensys/gnark/constraint/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" ) diff --git a/constraint/bw6-633/system.go b/constraint/bw6-633/system.go index f102b47290..7cd6843fc5 100644 --- a/constraint/bw6-633/system.go +++ b/constraint/bw6-633/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go new file mode 100644 index 0000000000..5d3edc00cf --- /dev/null +++ b/constraint/bw6-761/gkr.go @@ -0,0 +1,264 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by gnark DO NOT EDIT + +package cs + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} diff --git a/constraint/bw6-761/r1cs_test.go b/constraint/bw6-761/r1cs_test.go index b6eae1220c..d8f3b48039 100644 --- a/constraint/bw6-761/r1cs_test.go +++ b/constraint/bw6-761/r1cs_test.go @@ -28,7 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/consensys/gnark/constraint/bw6-761" + cs "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" ) diff --git a/constraint/bw6-761/system.go b/constraint/bw6-761/system.go index 10e7990dfa..75252436e8 100644 --- a/constraint/bw6-761/system.go +++ b/constraint/bw6-761/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/constraint/core.go b/constraint/core.go index e54fab3149..12aac144cb 100644 --- a/constraint/core.go +++ b/constraint/core.go @@ -125,6 +125,7 @@ type System struct { lbOutputs []uint32 `cbor:"-"` // wire outputs for current constraint. CommitmentInfo Commitments + GkrInfo GkrInfo genericHint BlueprintID } @@ -457,3 +458,12 @@ func putBuffer(buf *[]uint32) { } bufPool.Put(buf) } + +func (system *System) AddGkr(gkr GkrInfo) error { + if system.GkrInfo.Is() { + return fmt.Errorf("currently only one GKR sub-circuit per SNARK is supported") + } + + system.GkrInfo = gkr + return nil +} diff --git a/constraint/gkr.go b/constraint/gkr.go new file mode 100644 index 0000000000..4d84a5466d --- /dev/null +++ b/constraint/gkr.go @@ -0,0 +1,158 @@ +package constraint + +import ( + "fmt" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "sort" +) + +type GkrVariable int // Just an alias to hide implementation details. May be more trouble than worth + +type InputDependency struct { + OutputWire int + OutputInstance int + InputInstance int +} + +type GkrWire struct { + Gate string // TODO: Change to description + Inputs []int + Dependencies []InputDependency // nil for input wires + NbUniqueOutputs int +} + +type GkrCircuit []GkrWire + +type GkrInfo struct { + Circuit GkrCircuit + MaxNIns int + NbInstances int + HashName string + SolveHintID solver.HintID + ProveHintID solver.HintID +} + +type GkrPermutations struct { + SortedInstances []int + SortedWires []int + InstancesPermutation []int + WiresPermutation []int +} + +func (w GkrWire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w GkrWire) IsOutput() bool { + return w.NbUniqueOutputs == 0 +} + +// AssignmentOffsets returns the index of the first value assigned to a wire TODO: Explain clearly +func (d *GkrInfo) AssignmentOffsets() []int { + c := d.Circuit + res := make([]int, len(c)+1) + for i := range c { + nbExplicitAssignments := 0 + if c[i].IsInput() { + nbExplicitAssignments = d.NbInstances - len(c[i].Dependencies) + } + res[i+1] = res[i] + nbExplicitAssignments + } + return res +} + +func (d *GkrInfo) NewInputVariable() GkrVariable { + i := len(d.Circuit) + d.Circuit = append(d.Circuit, GkrWire{}) + return GkrVariable(i) +} + +// Compile sorts the circuit wires, their dependencies and the instances +func (d *GkrInfo) Compile(nbInstances int) (GkrPermutations, error) { + + var p GkrPermutations + d.NbInstances = nbInstances + // sort the instances to decide the order in which they are to be solved + instanceDeps := make([][]int, nbInstances) + for i := range d.Circuit { + for _, dep := range d.Circuit[i].Dependencies { + instanceDeps[dep.InputInstance] = append(instanceDeps[dep.InputInstance], dep.OutputInstance) + } + } + + p.SortedInstances, _ = algo_utils.TopologicalSort(instanceDeps) + p.InstancesPermutation = algo_utils.InvertPermutation(p.SortedInstances) + + // this whole circuit sorting is a bit of a charade. if things are built using an api, there's no way it could NOT already be topologically sorted + // worth keeping for future-proofing? + + inputs := algo_utils.Map(d.Circuit, func(w GkrWire) []int { + return w.Inputs + }) + + var uniqueOuts [][]int + p.SortedWires, uniqueOuts = algo_utils.TopologicalSort(inputs) + p.WiresPermutation = algo_utils.InvertPermutation(p.SortedWires) + wirePermutationAt := algo_utils.SliceAt(p.WiresPermutation) + sorted := make([]GkrWire, len(d.Circuit)) // TODO: Directly manipulate d.Circuit instead + for newI, oldI := range p.SortedWires { + oldW := d.Circuit[oldI] + + if !oldW.IsInput() { + d.MaxNIns = utils.Max(d.MaxNIns, len(oldW.Inputs)) + } + + for j := range oldW.Dependencies { + dep := &oldW.Dependencies[j] + dep.OutputWire = p.WiresPermutation[dep.OutputWire] + dep.InputInstance = p.InstancesPermutation[dep.InputInstance] + dep.OutputInstance = p.InstancesPermutation[dep.OutputInstance] + } + sort.Slice(oldW.Dependencies, func(i, j int) bool { + return oldW.Dependencies[i].InputInstance < oldW.Dependencies[j].InputInstance + }) + for i := 1; i < len(oldW.Dependencies); i++ { + if oldW.Dependencies[i].InputInstance == oldW.Dependencies[i-1].InputInstance { + return p, fmt.Errorf("an input wire can only have one dependency per instance") + } + } // TODO: Check that dependencies and explicit assignments cover all instances + + sorted[newI] = GkrWire{ + Gate: oldW.Gate, + Inputs: algo_utils.Map(oldW.Inputs, wirePermutationAt), + Dependencies: oldW.Dependencies, + NbUniqueOutputs: len(uniqueOuts[oldI]), + } + } + d.Circuit = sorted + + return p, nil +} + +func (d *GkrInfo) Is() bool { + return d.Circuit != nil +} + +// Chunks returns intervals of instances that are independent of each other and can be solved in parallel +func (c GkrCircuit) Chunks(nbInstances int) []int { + res := make([]int, 0, 1) + lastSeenDependencyI := make([]int, len(c)) + + for start, end := 0, 0; start != nbInstances; start = end { + end = nbInstances + endWireI := -1 + for wI, w := range c { + if wDepI := lastSeenDependencyI[wI]; wDepI < len(w.Dependencies) && w.Dependencies[wDepI].InputInstance < end { + end = w.Dependencies[wDepI].InputInstance + endWireI = wI + } + } + if endWireI != -1 { + lastSeenDependencyI[endWireI]++ + } + res = append(res, end) + } + return res +} diff --git a/constraint/system.go b/constraint/system.go index 19f8167f32..e03586af4e 100644 --- a/constraint/system.go +++ b/constraint/system.go @@ -20,7 +20,7 @@ type ConstraintSystem interface { // Deprecated: use _, err := Solve(...) instead IsSolved(witness witness.Witness, opts ...solver.Option) error - // Solve attempts to solves the constraint system using provided witness. + // Solve attempts to solve the constraint system using provided witness. // Returns an error if the witness does not allow all the constraints to be satisfied. // Returns a typed solution (R1CSSolution or SparseR1CSSolution) and nil otherwise. Solve(witness witness.Witness, opts ...solver.Option) (any, error) @@ -52,6 +52,7 @@ type ConstraintSystem interface { AddCommitment(c Commitment) error GetCommitments() Commitments + AddGkr(gkr GkrInfo) error AddLog(l LogEntry) diff --git a/constraint/term.go b/constraint/term.go index 857edd0da0..799fb4f7a3 100644 --- a/constraint/term.go +++ b/constraint/term.go @@ -54,7 +54,7 @@ func (t Term) String(r Resolver) string { return sbb.String() } -// implements constraint.Compressable +// implements constraint.Compressible // Compress compresses the term into a slice of uint32 words. // For compatibility with test engine and LinearExpression, the term is encoded as: diff --git a/constraint/tinyfield/system.go b/constraint/tinyfield/system.go index d75471e445..bcb551272b 100644 --- a/constraint/tinyfield/system.go +++ b/constraint/tinyfield/system.go @@ -372,3 +372,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} diff --git a/frontend/builder.go b/frontend/builder.go index 680a789b7a..717dbfca7a 100644 --- a/frontend/builder.go +++ b/frontend/builder.go @@ -64,6 +64,8 @@ type Compiler interface { // ToCanonicalVariable converts a frontend.Variable to a constraint system specific Variable // ! Experimental: use in conjunction with constraint.CustomizableSystem ToCanonicalVariable(Variable) CanonicalVariable + + SetGkrInfo(constraint.GkrInfo) error } // Builder represents a constraint system builder @@ -104,5 +106,5 @@ type Rangechecker interface { // a PLONK builder --> constraint.Term // and the test/Engine --> ~*big.Int. type CanonicalVariable interface { - constraint.Compressable + constraint.Compressible } diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 7bb01fa76c..797bba0854 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -815,3 +815,7 @@ func (builder *builder) wireIDsToVars(wireIDs ...[]int) []frontend.Variable { } return res } + +func (builder *builder) SetGkrInfo(info constraint.GkrInfo) error { + return builder.cs.AddGkr(info) +} diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 12eff1e730..75764878d3 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -619,3 +619,7 @@ func filterConstants(v []frontend.Variable) []frontend.Variable { func (*builder) FrontendType() frontendtype.Type { return frontendtype.SCS } + +func (builder *builder) SetGkrInfo(info constraint.GkrInfo) error { + return builder.cs.AddGkr(info) +} diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 4e79e1eb88..b5c0bfe1a8 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -129,6 +129,14 @@ func main() { panic(err) } + // gkr backend + if d.Curve != "tinyfield" { + entries = []bavard.Entry{{File: filepath.Join(csDir, "gkr.go"), Templates: []string{"gkr.go.tmpl", importCurve}}} + if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { + panic(err) + } + } + entries = []bavard.Entry{ {File: filepath.Join(csDir, "r1cs_test.go"), Templates: []string{"tests/r1cs.go.tmpl", importCurve}}, } diff --git a/internal/generator/backend/template/imports.go.tmpl b/internal/generator/backend/template/imports.go.tmpl index 56639bdfb2..949969f9d7 100644 --- a/internal/generator/backend/template/imports.go.tmpl +++ b/internal/generator/backend/template/imports.go.tmpl @@ -22,7 +22,7 @@ {{- if eq .Curve "tinyfield"}} "github.com/consensys/gnark/constraint/tinyfield" {{- else}} - "github.com/consensys/gnark/constraint/{{toLower .Curve}}" + cs "github.com/consensys/gnark/constraint/{{toLower .Curve}}" {{- end}} {{- end }} diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl new file mode 100644 index 0000000000..9c66dd8cf0 --- /dev/null +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -0,0 +1,246 @@ +import ( + {{- template "import_fr" .}} + {{- template "import_gkr" .}} + {{- template "import_polynomial" .}} + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/constraint" + hint "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/utils/algo_utils" + "hash" + "math/big" +) + +type GkrSolvingData struct { + assignments gkr.WireAssignment + circuit gkr.Circuit + memoryPool polynomial.Pool + workers *utils.WorkerPool +} + +func convertCircuit(noPtr constraint.GkrCircuit) gkr.Circuit { + resCircuit := make(gkr.Circuit, len(noPtr)) + for i := range noPtr { + resCircuit[i].Gate = GkrGateRegistry[noPtr[i].Gate] + resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) + } + return resCircuit +} + +func (d *GkrSolvingData) init(info constraint.GkrInfo) gkrAssignment { + d.circuit = convertCircuit(info.Circuit) + d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) + d.workers = utils.NewWorkerPool() + + assignmentsSequential := make(gkrAssignment, len(d.circuit)) + d.assignments = make(gkr.WireAssignment, len(d.circuit)) + for i := range assignmentsSequential { + assignmentsSequential[i] = d.memoryPool.Make(info.NbInstances) + d.assignments[&d.circuit[i]] = assignmentsSequential[i] + } + + return assignmentsSequential +} + +func (d *GkrSolvingData) dumpAssignments() { + for _, p := range d.assignments { + d.memoryPool.Dump(p) + } +} + +// this module assumes that wire and instance indexes respect dependencies + +type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second + +func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { + outsI := 0 + for i := range circuit { + if circuit[i].IsOutput() { + for j := range a[i] { + a[i][j].BigInt(outs[outsI]) + outsI++ + } + } + } + // Check if outsI == len(outs)? +} + +func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + // assumes assignmentVector is arranged wire first, instance second in order of solution + circuit := info.Circuit + nbInstances := info.NbInstances + offsets := info.AssignmentOffsets() + assignment := solvingData.init(info) + chunks := circuit.Chunks(nbInstances) + + solveTask := func(chunkOffset int) utils.Task { + return func(startInChunk, endInChunk int) { + start := startInChunk + chunkOffset + end := endInChunk + chunkOffset + inputs := solvingData.memoryPool.Make(info.MaxNIns) + dependencyHeads := make([]int, len(circuit)) + for wI, w := range circuit { + dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { + return w.Dependencies[i].InputInstance + }, len(w.Dependencies), start) + } + + for instanceI := start; instanceI < end; instanceI++ { + for wireI, wire := range circuit { + if wire.IsInput() { + if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { + dep := wire.Dependencies[dependencyHeads[wireI]] + assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) + dependencyHeads[wireI]++ + } else { + assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) + } + } else { + // assemble the inputs + inputIndexes := info.Circuit[wireI].Inputs + for i, inputI := range inputIndexes { + inputs[i].Set(&assignment[inputI][instanceI]) + } + gate := solvingData.circuit[wireI].Gate + assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) + } + } + } + solvingData.memoryPool.Dump(inputs) + } + } + + start := 0 + for _, end := range chunks { + solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() + start = end + } + + assignment.setOuts(info.Circuit, outs) + + return nil + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} + +func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { + + return func(_ *big.Int, ins, outs []*big.Int) error { + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + b := i.Bytes() + return b[:] + }) + + hsh := HashBuilderRegistry[hashName]() + + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + if err != nil { + return err + } + + // serialize proof: TODO: In gnark-crypto? + offset := 0 + for i := range proof { + for _, poly := range proof[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if proof[i].FinalEvalProof != nil { + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } + + data.dumpAssignments() + + return nil + + } +} + +var GkrGateRegistry = map[string]gkr.Gate{ // TODO: Migrate to gnark-crypto + "mul": mulGate(2), + "add": addGate{}, + "sub": subGate{}, + "neg": negGate{}, +} + +// TODO: Move to gnark-crypto +var HashBuilderRegistry = make(map[string]func() hash.Hash) + +type mulGate int +type addGate struct{} +type subGate struct{} +type negGate struct{} + +func (g mulGate) Evaluate(x ...fr.Element) (res fr.Element) { + if len(x) != int(g) { + panic("wrong input count") + } + switch len(x) { + case 0: + res.SetOne() + case 1: + res.Set(&x[0]) + default: + res.Mul(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Mul(&res, &x[2]) + } + } + return +} + +func (g mulGate) Degree() int { + return int(g) +} + +func (g addGate) Evaluate(x ...fr.Element) (res fr.Element) { + switch len(x) { + case 0: + // set zero + case 1: + res.Set(&x[0]) + case 2: + res.Add(&x[0], &x[1]) + for i := 2; i < len(x); i++ { + res.Add(&res, &x[2]) + } + } + return +} + +func (g addGate) Degree() int { + return 1 +} + +func (g subGate) Evaluate(element ...fr.Element) (diff fr.Element) { + if len(element) > 2 { + panic("not implemented") //TODO + } + diff.Sub(&element[0], &element[1]) + return +} + +func (g subGate) Degree() int { + return 1 +} + +func (g negGate) Evaluate(element ...fr.Element) (neg fr.Element) { + if len(element) != 1 { + panic("univariate gate") + } + neg.Neg(&element[0]) + return +} + +func (g negGate) Degree() int { + return 1 +} \ No newline at end of file diff --git a/internal/generator/backend/template/representations/system.go.tmpl b/internal/generator/backend/template/representations/system.go.tmpl index 7288fdb546..dc5b4a98d3 100644 --- a/internal/generator/backend/template/representations/system.go.tmpl +++ b/internal/generator/backend/template/representations/system.go.tmpl @@ -369,3 +369,7 @@ func getTagSet() cbor.TagSet { return ts } + +func (s *system) AddGkr(gkr constraint.GkrInfo) error { + return s.System.AddGkr(gkr) +} \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 760cf22d24..fa72568220 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -77,6 +77,13 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b }(i))) } + if r1cs.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + solverOpts = append(solverOpts, + solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), + solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) + } + _solution, err := r1cs.Solve(fullWitness, solverOpts...) if err != nil { return nil, err diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index a869528cf9..def44c21ec 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -105,6 +105,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts bsb22ComputeCommitmentHint(spr, pk, proof, cCommitments, &commitmentVal[i], i))) } + if spr.GkrInfo.Is() { + var gkrData cs.GkrSolvingData + opt.SolverOpts = append(opt.SolverOpts, + solver.OverrideHint(spr.GkrInfo.SolveHintID, cs.GkrSolveHint(spr.GkrInfo, &gkrData)), + solver.OverrideHint(spr.GkrInfo.ProveHintID, cs.GkrProveHint(spr.GkrInfo.HashName, &gkrData))) + } + // query l, r, o in Lagrange basis, not blinded _solution, err := spr.Solve(fullWitness, opt.SolverOpts...) if err != nil { diff --git a/std/fiat-shamir/transcript.go b/std/fiat-shamir/transcript.go index a8d82547d0..63fe8e786d 100644 --- a/std/fiat-shamir/transcript.go +++ b/std/fiat-shamir/transcript.go @@ -18,6 +18,7 @@ package fiatshamir import ( "errors" + "github.com/consensys/gnark/constant" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" diff --git a/std/gkr/api.go b/std/gkr/api.go new file mode 100644 index 0000000000..ad67fb6fcb --- /dev/null +++ b/std/gkr/api.go @@ -0,0 +1,50 @@ +package gkr + +import ( + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/std/utils/algo_utils" +) + +func frontendVarToInt(a constraint.GkrVariable) int { + return int(a) +} + +func (api *API) newNonInputVariable(gate string, in []constraint.GkrVariable) constraint.GkrVariable { + api.toStore.Circuit = append(api.toStore.Circuit, constraint.GkrWire{ + Gate: gate, + Inputs: algo_utils.Map(in, frontendVarToInt), + }) + api.assignments = append(api.assignments, nil) + return constraint.GkrVariable(len(api.toStore.Circuit) - 1) +} + +func (api *API) newVar2PlusIn(gate string, in1, in2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + inCombined := make([]constraint.GkrVariable, 2+len(in)) + inCombined[0] = in1 + inCombined[1] = in2 + for i := range in { + inCombined[i+2] = in[i] + } + return api.newNonInputVariable(gate, inCombined) +} + +func (api *API) Add(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + return api.newVar2PlusIn("add", i1, i2, in...) +} + +func (api *API) Neg(i1 constraint.GkrVariable) constraint.GkrVariable { + return api.newNonInputVariable("neg", []constraint.GkrVariable{i1}) +} + +func (api *API) Sub(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + return api.newVar2PlusIn("sub", i1, i2, in...) +} + +func (api *API) Mul(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { + return api.newVar2PlusIn("mul", i1, i2, in...) +} + +// TODO @Tabaie This can be useful +func (api *API) Println(a ...constraint.GkrVariable) { + panic("not implemented") +} diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go new file mode 100644 index 0000000000..a41670be0e --- /dev/null +++ b/std/gkr/api_test.go @@ -0,0 +1,661 @@ +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" + "hash" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/constraint" + bn254r1cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + stdHash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/hash/mimc" + test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" +) + +// compressThreshold --> if linear expressions are larger than this, the frontend will introduce +// intermediate constraints. The lower this number is, the faster compile time should be (to a point) +// but resulting circuit will have more constraints (slower proving time). +const compressThreshold = 1000 + +type doubleNoDependencyCircuit struct { + X []frontend.Variable + hashName string +} + +func (c *doubleNoDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x constraint.GkrVariable + var err error + if x, err = gkr.Import(c.X); err != nil { + return err + } + z := gkr.Add(x, x) + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + Z := solution.Export(z) + + for i := range Z { + api.AssertIsEqual(Z[i], api.Mul(2, c.X[i])) + } + + return solution.Verify(c.hashName) +} + +func TestDoubleNoDependencyCircuit(t *testing.T) { + + xValuess := [][]frontend.Variable{ + {1, 1}, + {1, 2}, + } + + hashes := []string{"-1", "-20"} + + for _, xValues := range xValuess { + for _, hashName := range hashes { + assignment := doubleNoDependencyCircuit{X: xValues} + circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName} + + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) + } + } +} + +type sqNoDependencyCircuit struct { + X []frontend.Variable + hashName string +} + +func (c *sqNoDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x constraint.GkrVariable + var err error + if x, err = gkr.Import(c.X); err != nil { + return err + } + z := gkr.Mul(x, x) + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + Z := solution.Export(z) + + for i := range Z { + api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.X[i])) + } + + return solution.Verify(c.hashName) +} + +func TestSqNoDependencyCircuit(t *testing.T) { + + xValuess := [][]frontend.Variable{ + {1, 1}, + {1, 2}, + } + + hashes := []string{"-1", "-20"} + + for _, xValues := range xValuess { + for _, hashName := range hashes { + assignment := sqNoDependencyCircuit{X: xValues} + circuit := sqNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName} + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) + } + } +} + +type mulNoDependencyCircuit struct { + X, Y []frontend.Variable + hashName string +} + +func (c *mulNoDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x, y constraint.GkrVariable + var err error + if x, err = gkr.Import(c.X); err != nil { + return err + } + if y, err = gkr.Import(c.Y); err != nil { + return err + } + z := gkr.Mul(x, y) + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + X := solution.Export(x) + Y := solution.Export(y) + Z := solution.Export(z) + api.Println("after solving, z=", Z, ", x=", X, ", y=", Y) + + for i := range c.X { + api.Println("z@", i, " = ", Z[i]) + api.Println("x.y = ", api.Mul(c.X[i], c.Y[i])) + api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.Y[i])) + } + + return solution.Verify(c.hashName) +} + +func TestMulNoDependency(t *testing.T) { + xValuess := [][]frontend.Variable{ + {1, 2}, + } + yValuess := [][]frontend.Variable{ + {0, 3}, + } + + hashes := []string{"-1", "-20"} + + for i := range xValuess { + for _, hashName := range hashes { + + assignment := mulNoDependencyCircuit{ + X: xValuess[i], + Y: yValuess[i], + } + circuit := mulNoDependencyCircuit{ + X: make([]frontend.Variable, len(xValuess[i])), + Y: make([]frontend.Variable, len(yValuess[i])), + hashName: hashName, + } + + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) + } + } +} + +type mulWithDependencyCircuit struct { + XLast frontend.Variable + Y []frontend.Variable + hashName string +} + +func (c *mulWithDependencyCircuit) Define(api frontend.API) error { + gkr := NewApi() + var x, y constraint.GkrVariable + var err error + + X := make([]frontend.Variable, len(c.Y)) + X[len(c.Y)-1] = c.XLast + if x, err = gkr.Import(X); err != nil { + return err + } + if y, err = gkr.Import(c.Y); err != nil { + return err + } + z := gkr.Mul(x, y) + + for i := len(X) - 1; i > 0; i-- { + gkr.Series(x, z, i-1, i) + } + + var solution Solution + if solution, err = gkr.Solve(api); err != nil { + return err + } + X = solution.Export(x) + Y := solution.Export(y) + Z := solution.Export(z) + + api.Println("after solving, z=", Z, ", x=", X, ", y=", Y) + + lastI := len(X) - 1 + api.AssertIsEqual(Z[lastI], api.Mul(c.XLast, Y[lastI])) + for i := 0; i < lastI; i++ { + api.AssertIsEqual(Z[i], api.Mul(Z[i+1], Y[i])) + } + return solution.Verify(c.hashName) +} + +func TestSolveMulWithDependency(t *testing.T) { + assignment := mulWithDependencyCircuit{ + XLast: 1, + Y: []frontend.Variable{3, 2}, + } + circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"} + + testGroth16(t, &circuit, &assignment) + testPlonk(t, &circuit, &assignment) +} + +func TestApiMul(t *testing.T) { + var ( + x constraint.GkrVariable + y constraint.GkrVariable + z constraint.GkrVariable + err error + ) + api := NewApi() + x, err = api.Import([]frontend.Variable{nil, nil}) + require.NoError(t, err) + y, err = api.Import([]frontend.Variable{nil, nil}) + require.NoError(t, err) + z = api.Mul(x, y) + test_vector_utils.AssertSliceEqual(t, api.toStore.Circuit[z].Inputs, []int{int(x), int(y)}) // TODO: Find out why assert.Equal gives false positives ( []*Wire{x,x} as second argument passes when it shouldn't ) +} + +func BenchmarkMiMCMerkleTree(b *testing.B) { + depth := 14 + //fmt.Println("start") + bottom := make([]frontend.Variable, 1<= 0; d-- { + for i := 0; i < 1<= 0 { + cached, slice[j] = slice[j], cached + j, permutation[j] = permutation[j], ^permutation[j] + } + permutation[next] = ^permutation[next] + } + for i := range permutation { + permutation[i] = ^permutation[i] + } +} + +func Map[T, S any](in []T, f func(T) S) []S { + out := make([]S, len(in)) + for i, t := range in { + out[i] = f(t) + } + return out +} + +func MapRange[S any](begin, end int, f func(int) S) []S { + out := make([]S, end-begin) + for i := begin; i < end; i++ { + out[i] = f(i) + } + return out +} + +func SliceAt[T any](slice []T) func(int) T { + return func(i int) T { + return slice[i] + } +} + +func SlicePtrAt[T any](slice []T) func(int) *T { + return func(i int) *T { + return &slice[i] + } +} + +func MapAt[K comparable, V any](mp map[K]V) func(K) V { + return func(k K) V { + return mp[k] + } +} + +// InvertPermutation input permutation must contain exactly 0, ..., len(permutation)-1 +func InvertPermutation(permutation []int) []int { + res := make([]int, len(permutation)) + for i := range permutation { + res[permutation[i]] = i + } + return res +} + +// TODO: Move this to gnark-crypto and use it for gkr there as well + +// TopologicalSort takes a list of lists of dependencies and proposes a sorting of the lists in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// As a bonus, it returns for each list its "unique" outputs. That is, a list of its outputs with no duplicates. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input. +// If performance was bad, consider using a heap for finding the value "leastReady". +// WARNING: Due to the current implementation of intSet, it is ALWAYS O(n^2). +func TopologicalSort(inputs [][]int) (sorted []int, uniqueOutputs [][]int) { + data := newTopSortData(inputs) + sorted = make([]int, len(inputs)) + + for i := range inputs { + sorted[i] = data.leastReady + data.markDone(data.leastReady) + } + + return sorted, data.uniqueOutputs +} + +type topSortData struct { + uniqueOutputs [][]int + inputs [][]int + status []int // status > 0 indicates number of unique inputs left to be ready. status = 0 means ready. status = -1 means done + leastReady int +} + +func newTopSortData(inputs [][]int) topSortData { + size := len(inputs) + res := topSortData{ + uniqueOutputs: make([][]int, size), + inputs: inputs, + status: make([]int, size), + leastReady: 0, + } + + inputsISet := bitset.New(uint(size)) + for i := range res.uniqueOutputs { + if i != 0 { + inputsISet.ClearAll() + } + cpt := 0 + for _, in := range inputs[i] { + if !inputsISet.Test(uint(in)) { + inputsISet.Set(uint(in)) + cpt++ + res.uniqueOutputs[in] = append(res.uniqueOutputs[in], i) + } + } + res.status[i] = cpt + } + + for res.status[res.leastReady] != 0 { + res.leastReady++ + } + + return res +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.uniqueOutputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +// BinarySearch looks for toFind in a sorted slice, and returns the index at which it either is or would be were it to be inserted. +func BinarySearch(slice []int, toFind int) int { + var start int + for end := len(slice); start != end; { + mid := (start + end) / 2 + if toFind >= slice[mid] { + start = mid + } + if toFind <= slice[mid] { + end = mid + } + } + return start +} + +// BinarySearchFunc looks for toFind in an increasing function of domain 0 ... (end-1), and returns the index at which it either is or would be were it to be inserted. +func BinarySearchFunc(eval func(int) int, end int, toFind int) int { + var start int + for start != end { + mid := (start + end) / 2 + val := eval(mid) + if toFind >= val { + start = mid + } + if toFind <= val { + end = mid + } + } + return start +} diff --git a/std/utils/algo_utils/algo_utils_test.go b/std/utils/algo_utils/algo_utils_test.go new file mode 100644 index 0000000000..85ab4bf294 --- /dev/null +++ b/std/utils/algo_utils/algo_utils_test.go @@ -0,0 +1,69 @@ +package algo_utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func SliceLen[T any](slice []T) int { + return len(slice) +} + +func testTopSort(t *testing.T, inputs [][]int, expectedSorted, expectedNbUniqueOuts []int) { + sorted, uniqueOuts := TopologicalSort(inputs) + nbUniqueOut := Map(uniqueOuts, SliceLen[int]) + assert.Equal(t, expectedSorted, sorted) + assert.Equal(t, expectedNbUniqueOuts, nbUniqueOut) +} + +func TestTopSortTrivial(t *testing.T) { + testTopSort(t, [][]int{ + {1}, + {}, + }, []int{1, 0}, []int{0, 1}) +} + +func TestTopSortSingleGate(t *testing.T) { + inputs := [][]int{{1, 2}, {}, {}} + expectedSorted := []int{1, 2, 0} + expectedNbUniqueOuts := []int{0, 1, 1} + testTopSort(t, inputs, expectedSorted, expectedNbUniqueOuts) +} + +func TestTopSortDeep(t *testing.T) { + inputs := [][]int{{2}, {3}, {}, {0}} + expectedSorted := []int{2, 0, 3, 1} + expectedNbUniqueOuts := []int{1, 0, 1, 1} + + testTopSort(t, inputs, expectedSorted, expectedNbUniqueOuts) +} + +func TestTopSortWide(t *testing.T) { + inputs := [][]int{ + {3, 8}, + {6}, + {4}, + {}, + {}, + {9}, + {9}, + {9, 5, 2, 2}, + {4, 3}, + {}, + } + expectedSorted := []int{3, 4, 2, 8, 0, 9, 5, 6, 1, 7} + expectedNbUniqueOut := []int{0, 0, 1, 2, 2, 1, 1, 0, 1, 3} + + testTopSort(t, inputs, expectedSorted, expectedNbUniqueOut) +} + +func TestPermute(t *testing.T) { + list := []int{34, 65, 23, 2, 5} + permutation := []int{2, 0, 1, 4, 3} + permutationCopy := make([]int, len(permutation)) + copy(permutationCopy, permutation) + + Permute(list, permutation) + assert.Equal(t, []int{65, 23, 34, 5, 2}, list) + assert.Equal(t, permutationCopy, permutation) +} diff --git a/std/utils/test_vectors_utils/test_vector_utils.go b/std/utils/test_vectors_utils/test_vector_utils.go new file mode 100644 index 0000000000..2f5dbc4a38 --- /dev/null +++ b/std/utils/test_vectors_utils/test_vector_utils.go @@ -0,0 +1,260 @@ +package test_vector_utils + +import ( + "encoding/json" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +// These data structures fail to equate different representations of the same number. i.e. 5 = -10/-2 +// @Tabaie TODO Replace with proper lookup tables + +type Map struct { + keys []frontend.Variable + values []frontend.Variable +} + +func getDelta(api frontend.API, x frontend.Variable, deltaIndex int, keys []frontend.Variable) frontend.Variable { + num := frontend.Variable(1) + den := frontend.Variable(1) + + for i, key := range keys { + if i != deltaIndex { + num = api.Mul(num, api.Sub(key, x)) + den = api.Mul(den, api.Sub(key, keys[deltaIndex])) + } + } + + return api.Div(num, den) +} + +// Get returns garbage if key is not present +func (m Map) Get(api frontend.API, key frontend.Variable) frontend.Variable { + res := frontend.Variable(0) + + for i := range m.keys { + deltaI := getDelta(api, key, i, m.keys) + res = api.Add(res, api.Mul(deltaI, m.values[i])) + } + + return res +} + +// The keys in a DoubleMap must be constant. i.e. known at setup time +type DoubleMap struct { + keys1 []frontend.Variable + keys2 []frontend.Variable + values [][]frontend.Variable +} + +// Get is very inefficient. Do not use outside testing +func (m DoubleMap) Get(api frontend.API, key1, key2 frontend.Variable) frontend.Variable { + deltas1 := make([]frontend.Variable, len(m.keys1)) + deltas2 := make([]frontend.Variable, len(m.keys2)) + + for i := range deltas1 { + deltas1[i] = getDelta(api, key1, i, m.keys1) + } + + for j := range deltas2 { + deltas2[j] = getDelta(api, key2, j, m.keys2) + } + + res := frontend.Variable(0) + + for i := range deltas1 { + for j := range deltas2 { + if m.values[i][j] != nil { + deltaIJ := api.Mul(deltas1[i], deltas2[j], m.values[i][j]) + res = api.Add(res, deltaIJ) + } + } + } + + return res +} + +func register[K comparable](m map[K]int, key K) { + if _, ok := m[key]; !ok { + m[key] = len(m) + } +} + +func orderKeys[K comparable](order map[K]int) (ordered []K) { + ordered = make([]K, len(order)) + for k, i := range order { + ordered[i] = k + } + return +} + +type ElementMap struct { + single Map + double DoubleMap +} + +func ReadMap(in map[string]interface{}) ElementMap { + single := Map{ + keys: make([]frontend.Variable, 0), + values: make([]frontend.Variable, 0), + } + + keys1 := make(map[string]int) + keys2 := make(map[string]int) + + for k, v := range in { + + kSep := strings.Split(k, ",") + switch len(kSep) { + case 1: + single.keys = append(single.keys, k) + single.values = append(single.values, ToVariable(v)) + case 2: + + register(keys1, kSep[0]) + register(keys2, kSep[1]) + + default: + panic("too many keys") + } + } + + vals := make([][]frontend.Variable, len(keys1)) + for i := range vals { + vals[i] = make([]frontend.Variable, len(keys2)) + } + + for k, v := range in { + kSep := strings.Split(k, ",") + if len(kSep) == 2 { + i1 := keys1[kSep[0]] + i2 := keys2[kSep[1]] + vals[i1][i2] = ToVariable(v) + } + } + + double := DoubleMap{ + keys1: ToVariableSlice(orderKeys(keys1)), + keys2: ToVariableSlice(orderKeys(keys2)), + values: vals, + } + + return ElementMap{ + single: single, + double: double, + } +} + +func ToVariable(v interface{}) frontend.Variable { + switch vT := v.(type) { + case float64: + return int(vT) + default: + return v + } +} + +func ToVariableSlice[V any](slice []V) (variableSlice []frontend.Variable) { + variableSlice = make([]frontend.Variable, len(slice)) + for i := range slice { + variableSlice[i] = ToVariable(slice[i]) + } + return +} + +func ToVariableSliceSlice[V any](sliceSlice [][]V) (variableSliceSlice [][]frontend.Variable) { + variableSliceSlice = make([][]frontend.Variable, len(sliceSlice)) + for i := range sliceSlice { + variableSliceSlice[i] = ToVariableSlice(sliceSlice[i]) + } + return +} + +func ToMap(keys1, keys2, values []frontend.Variable) map[string]interface{} { + res := make(map[string]interface{}, len(keys1)) + for i := range keys1 { + str := strconv.Itoa(keys1[i].(int)) + "," + strconv.Itoa(keys2[i].(int)) + res[str] = values[i].(int) + } + return res +} + +var MapCache = make(map[string]ElementMap) // @Tabaie: global bad? + +func ElementMapFromFile(path string) (ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return ElementMap{}, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return ElementMap{}, err + } + + res := ReadMap(asMap) + MapCache[path] = res + return res, nil + + } else { + return ElementMap{}, err + } +} + +type MapHash struct { + Map ElementMap + state frontend.Variable + API frontend.API + stateValid bool +} + +func (m *MapHash) Sum() frontend.Variable { + return m.state +} + +func (m *MapHash) Write(data ...frontend.Variable) { + for _, x := range data { + m.write(x) + } +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) write(x frontend.Variable) { + if m.stateValid { + m.state = m.Map.double.Get(m.API, x, m.state) + } else { + m.state = m.Map.single.Get(m.API, x) + } + m.stateValid = true +} + +func AssertSliceEqual[T comparable](t *testing.T, expected, seen []T) { + assert.Equal(t, len(expected), len(seen)) + for i := range seen { + assert.True(t, expected[i] == seen[i], "@%d: %v != %v", i, expected[i], seen[i]) // assert.Equal is not strict enough when comparing pointers, i.e. it compares what they refer to + } +} + +func SliceEqual[T comparable](expected, seen []T) bool { + if len(expected) != len(seen) { + return false + } + for i := range seen { + if expected[i] != seen[i] { + return false + } + } + return true +} diff --git a/std/utils/test_vectors_utils/test_vector_utils_test.go b/std/utils/test_vectors_utils/test_vector_utils_test.go new file mode 100644 index 0000000000..2193fb789e --- /dev/null +++ b/std/utils/test_vectors_utils/test_vector_utils_test.go @@ -0,0 +1,148 @@ +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/assert" + "testing" +) + +type TestSingleMapCircuit struct { + M Map `gnark:"-"` + Values []frontend.Variable +} + +func (c *TestSingleMapCircuit) Define(api frontend.API) error { + + for i, k := range c.M.keys { + v := c.M.Get(api, k) + api.AssertIsEqual(v, c.Values[i]) + } + + return nil +} + +func TestSingleMap(t *testing.T) { + m := map[string]interface{}{ + "1": -2, + "4": 1, + "6": 7, + } + single := ReadMap(m).single + + assignment := TestSingleMapCircuit{ + M: single, + Values: single.values, + } + + circuit := TestSingleMapCircuit{ + M: single, + Values: make([]frontend.Variable, len(m)), // Okay to use the same object? + } + + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithBackends(backend.GROTH16)) +} + +type TestDoubleMapCircuit struct { + M DoubleMap `gnark:"-"` + Values []frontend.Variable + Keys1 []frontend.Variable `gnark:"-"` + Keys2 []frontend.Variable `gnark:"-"` +} + +func (c *TestDoubleMapCircuit) Define(api frontend.API) error { + + for i := range c.Keys1 { + v := c.M.Get(api, c.Keys1[i], c.Keys2[i]) + api.AssertIsEqual(v, c.Values[i]) + } + + return nil +} + +func TestReadDoubleMap(t *testing.T) { + keys1 := []frontend.Variable{1, 2} + keys2 := []frontend.Variable{1, 0} + values := []frontend.Variable{3, 1} + + for i := 0; i < 100; i++ { + m := ToMap(keys1, keys2, values) + double := ReadMap(m).double + valuesOrdered := [][]frontend.Variable{{3, nil}, {nil, 1}} + + assert.True(t, double.keys1[0] == "1" && double.keys1[1] == "2" || double.keys1[0] == "2" && double.keys1[1] == "1") + assert.True(t, double.keys2[0] == "1" && double.keys2[1] == "0" || double.keys2[0] == "0" && double.keys2[1] == "1") + + if double.keys1[0] != "1" { + valuesOrdered[0], valuesOrdered[1] = valuesOrdered[1], valuesOrdered[0] + } + + if double.keys2[0] != "1" { + valuesOrdered[0][0], valuesOrdered[0][1] = valuesOrdered[0][1], valuesOrdered[0][0] + valuesOrdered[1][0], valuesOrdered[1][1] = valuesOrdered[1][1], valuesOrdered[1][0] + } + + assert.True(t, slice2Eq(valuesOrdered, double.values)) + + } + +} + +func slice2Eq(s1, s2 [][]frontend.Variable) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if !sliceEq(s1[i], s2[i]) { + return false + } + } + return true +} + +func sliceEq(s1, s2 []frontend.Variable) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} + +func TestDoubleMap(t *testing.T) { + keys1 := []frontend.Variable{1, 5, 5, 3} + keys2 := []frontend.Variable{1, -5, 4, 4} + values := []frontend.Variable{0, 2, 3, 0} + + m := ToMap(keys1, keys2, values) + double := ReadMap(m).double + + fmt.Println(double) + + assignment := TestDoubleMapCircuit{ + M: double, + Values: values, + Keys1: keys1, + Keys2: keys2, + } + + circuit := TestDoubleMapCircuit{ + M: double, + Keys1: keys1, + Keys2: keys2, + Values: make([]frontend.Variable, len(m)), // Okay to use the same object? + } + + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithBackends(backend.GROTH16)) +} + +func TestDoubleMapManyTimes(t *testing.T) { + for i := 0; i < 100; i++ { + TestDoubleMap(t) + } +} diff --git a/test/engine.go b/test/engine.go index 3a7c348f8a..afdcd1457c 100644 --- a/test/engine.go +++ b/test/engine.go @@ -678,3 +678,7 @@ func (e *engine) ToCanonicalVariable(v frontend.Variable) frontend.CanonicalVari r := e.toBigInt(v) return wrappedBigInt{r} } + +func (e *engine) SetGkrInfo(info constraint.GkrInfo) error { + return fmt.Errorf("not implemented") +}