From 9d44a4d164215959c0e6306967a2b6c9e302288a Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Fri, 30 Jun 2023 15:45:35 -0500 Subject: [PATCH] build: generify system.Equal --- constraint/bls12-377/system.go | 14 ++-- constraint/bls12-381/system.go | 14 ++-- constraint/bls24-315/system.go | 14 ++-- constraint/bls24-317/system.go | 14 ++-- constraint/bw6-633/system.go | 14 ++-- constraint/bw6-761/system.go | 14 ++-- constraint/tinyfield/system.go | 14 ++-- .../template/representations/system.go.tmpl | 14 ++-- std/gkr/api_test.go | 64 ++----------------- std/gkr/placeholder_hints.go | 12 ++++ 10 files changed, 48 insertions(+), 140 deletions(-) diff --git a/constraint/bls12-377/system.go b/constraint/bls12-377/system.go index a86d64e63b..a28aba3286 100644 --- a/constraint/bls12-377/system.go +++ b/constraint/bls12-377/system.go @@ -382,18 +382,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/constraint/bls12-381/system.go b/constraint/bls12-381/system.go index c3ac1d1794..d083492389 100644 --- a/constraint/bls12-381/system.go +++ b/constraint/bls12-381/system.go @@ -382,18 +382,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/constraint/bls24-315/system.go b/constraint/bls24-315/system.go index 7e9f4189a3..5ce3628f62 100644 --- a/constraint/bls24-315/system.go +++ b/constraint/bls24-315/system.go @@ -382,18 +382,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/constraint/bls24-317/system.go b/constraint/bls24-317/system.go index 2754bb6ff3..124dc1e839 100644 --- a/constraint/bls24-317/system.go +++ b/constraint/bls24-317/system.go @@ -382,18 +382,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/constraint/bw6-633/system.go b/constraint/bw6-633/system.go index 7be8a5db79..06767f667e 100644 --- a/constraint/bw6-633/system.go +++ b/constraint/bw6-633/system.go @@ -382,18 +382,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/constraint/bw6-761/system.go b/constraint/bw6-761/system.go index ab2e941cb4..0955e9b8fc 100644 --- a/constraint/bw6-761/system.go +++ b/constraint/bw6-761/system.go @@ -382,18 +382,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/constraint/tinyfield/system.go b/constraint/tinyfield/system.go index e3792d79fb..3c9f98dc7c 100644 --- a/constraint/tinyfield/system.go +++ b/constraint/tinyfield/system.go @@ -378,18 +378,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } diff --git a/internal/generator/backend/template/representations/system.go.tmpl b/internal/generator/backend/template/representations/system.go.tmpl index 65c13b1b6c..ec718d724e 100644 --- a/internal/generator/backend/template/representations/system.go.tmpl +++ b/internal/generator/backend/template/representations/system.go.tmpl @@ -379,18 +379,12 @@ func (s *system) AddGkr(gkr constraint.GkrInfo) error { } func (s *system) Equal(other constraint.ConstraintSystem) bool { + if !s.GkrInfo.Is() { + return reflect.DeepEqual(s, other) // fast track + } if o, ok := other.(*system); !ok { return false } else { - oHints := o.MHintsDependencies - - if match := constraint.HintsEqual(s.MHintsDependencies, oHints); !match { - return false - } - - o.MHintsDependencies = s.MHintsDependencies - match := reflect.DeepEqual(s, o) - o.MHintsDependencies = oHints - return match + return reflect.DeepEqual(s.field, o.field) && reflect.DeepEqual(s.CoeffTable, o.CoeffTable) && constraint.SystemEqual(s.System, o.System) } } \ No newline at end of file diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index b7870efcd2..c209042805 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -2,9 +2,7 @@ package gkr import ( "fmt" - "github.com/consensys/gnark-crypto/kzg" "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/plonk" bn254r1cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" @@ -19,11 +17,9 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" "github.com/consensys/gnark/backend/groth16" - "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/frontend/cs/scs" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" @@ -74,9 +70,7 @@ func TestDoubleNoDependencyCircuit(t *testing.T) { assignment := doubleNoDependencyCircuit{X: xValues} circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName} - test.NewAssert(t).SolvingSucceeded(&circuit, &assignment, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254)) - //testGroth16(t, &circuit, &assignment) - //testPlonk(t, &circuit, &assignment) + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) } } } @@ -120,8 +114,7 @@ func TestSqNoDependencyCircuit(t *testing.T) { for _, hashName := range hashes { assignment := sqNoDependencyCircuit{X: xValues} circuit := sqNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName} - testGroth16(t, &circuit, &assignment) - testPlonk(t, &circuit, &assignment) + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) } } } @@ -183,8 +176,7 @@ func TestMulNoDependency(t *testing.T) { hashName: hashName, } - testGroth16(t, &circuit, &assignment) - testPlonk(t, &circuit, &assignment) + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) } } } @@ -239,8 +231,7 @@ func TestSolveMulWithDependency(t *testing.T) { } circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"} - testGroth16(t, &circuit, &assignment) - testPlonk(t, &circuit, &assignment) + test.NewAssert(t).ProverSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254)) } func TestApiMul(t *testing.T) { @@ -380,53 +371,6 @@ func (c *benchMiMCMerkleTreeCircuit) Define(api frontend.API) error { return solution.Verify("-20", challenge) } -func testGroth16(t *testing.T, circuit, assignment frontend.Circuit) { - cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold)) - require.NoError(t, err) - var ( - fullWitness witness.Witness - publicWitness witness.Witness - pk groth16.ProvingKey - vk groth16.VerifyingKey - proof groth16.Proof - ) - fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField()) - require.NoError(t, err) - publicWitness, err = fullWitness.Public() - require.NoError(t, err) - pk, vk, err = groth16.Setup(cs) - require.NoError(t, err) - proof, err = groth16.Prove(cs, pk, fullWitness) - require.NoError(t, err) - err = groth16.Verify(proof, vk, publicWitness) - require.NoError(t, err) -} - -func testPlonk(t *testing.T, circuit, assignment frontend.Circuit) { - cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold)) - require.NoError(t, err) - var ( - fullWitness witness.Witness - publicWitness witness.Witness - pk plonk.ProvingKey - vk plonk.VerifyingKey - proof plonk.Proof - kzgSrs kzg.SRS - ) - fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField()) - require.NoError(t, err) - publicWitness, err = fullWitness.Public() - require.NoError(t, err) - kzgSrs, err = test.NewKZGSRS(cs) - require.NoError(t, err) - pk, vk, err = plonk.Setup(cs, kzgSrs) - require.NoError(t, err) - proof, err = plonk.Prove(cs, pk, fullWitness) - require.NoError(t, err) - err = plonk.Verify(proof, vk, publicWitness) - require.NoError(t, err) -} - func registerMiMC() { bn254r1cs.HashBuilderRegistry["mimc"] = bn254MiMC.NewMiMC stdHash.BuilderRegistry["mimc"] = func(api frontend.API) (stdHash.FieldHasher, error) { diff --git a/std/gkr/placeholder_hints.go b/std/gkr/placeholder_hints.go index 2d63eb5350..3ba5621445 100644 --- a/std/gkr/placeholder_hints.go +++ b/std/gkr/placeholder_hints.go @@ -71,8 +71,20 @@ func ProveHintPlaceholderGenerator(hashName string, solveHintId, proveHintId sol curve := utils.FieldToCurve(mod) switch curve { + case ecc.BLS12_377: + err = bls12_377.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bls12_377.GkrSolvingData))(mod, in, out) + case ecc.BLS12_381: + err = bls12_381.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bls12_381.GkrSolvingData))(mod, in, out) + case ecc.BLS24_315: + err = bls24_315.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bls24_315.GkrSolvingData))(mod, in, out) + case ecc.BLS24_317: + err = bls24_317.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bls24_317.GkrSolvingData))(mod, in, out) case ecc.BN254: err = bn254.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bn254.GkrSolvingData))(mod, in, out) + case ecc.BW6_633: + err = bw6_633.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bw6_633.GkrSolvingData))(mod, in, out) + case ecc.BW6_761: + err = bw6_761.GkrProveHint(hashName, placeholderGkrSolvingData[solveHintId].(*bw6_761.GkrSolvingData))(mod, in, out) default: err = errors.New("unsupported curve") }