From 08e1a348a7b8138e9de288f807aa3cd886ac2aa8 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Thu, 2 Nov 2023 19:24:09 +0100 Subject: [PATCH] feat: add short-hash backed FS transcript --- std/recursion/wrapped_hash.go | 17 ++++++ std/recursion/wrapped_hash_test.go | 85 +++++++++++++++++++++++++++--- 2 files changed, 96 insertions(+), 6 deletions(-) diff --git a/std/recursion/wrapped_hash.go b/std/recursion/wrapped_hash.go index 1565608576..fff9075c28 100644 --- a/std/recursion/wrapped_hash.go +++ b/std/recursion/wrapped_hash.go @@ -9,6 +9,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" cryptomimc "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/frontend" + fiatshamir "github.com/consensys/gnark/std/fiat-shamir" stdhash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" @@ -174,6 +175,22 @@ func NewHash(api frontend.API, target *big.Int, bitmode bool) (stdhash.FieldHash return newHashFromParameter(api, &h, nbBits, bitmode), nil } +// NewTranscript returns a new Fiat-Shamir transcript for computing bound +// challenges. It uses hasher returned by [NewHash] internally and configures +// the transcript to be compatible with gnark-crypto Fiat-Shamir transcript. +func NewTranscript(api frontend.API, target *big.Int, challenges []string) (*fiatshamir.Transcript, error) { + h, err := NewHash(api, target, true) + if err != nil { + return nil, fmt.Errorf("new hash: %w", err) + } + nbBits := target.BitLen() + if nbBits > api.Compiler().FieldBitLen() { + nbBits = api.Compiler().FieldBitLen() + } + fs := fiatshamir.NewTranscript(api, h, challenges, fiatshamir.WithTryBitmode(((nbBits+7)/8)*8-8)) + return fs, nil +} + func (h *shortCircuitHash) Sum() frontend.Variable { // before we compute the digest we have to write the rest of the buffer into // the underlying hash. We know that we have maximum one variable left, as diff --git a/std/recursion/wrapped_hash_test.go b/std/recursion/wrapped_hash_test.go index f5969450f0..1370d66ec6 100644 --- a/std/recursion/wrapped_hash_test.go +++ b/std/recursion/wrapped_hash_test.go @@ -13,12 +13,14 @@ import ( fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bn254" fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" + cryptofs "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra" "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" + "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/std/algebra/native/sw_bls24315" - fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/recursion" "github.com/consensys/gnark/test" ) @@ -268,18 +270,17 @@ type transcriptCircuit[S algebra.ScalarT, G1El algebra.G1ElementT] struct { } func (c *transcriptCircuit[S, G1El]) Define(api frontend.API) error { - h, err := recursion.NewHash(api, c.target, true) + fs, err := recursion.NewTranscript(api, c.target, c.Challenges[:]) if err != nil { - return fmt.Errorf("new hash: %w", err) + return fmt.Errorf("new transcript: %w", err) } curve, err := algebra.GetCurve[S, G1El](api) if err != nil { return fmt.Errorf("get curve: %w", err) } - fs := fiatshamir.NewTranscript(api, h, c.Challenges[:]...) for i := range c.Points { for j := range c.Points[i] { - if err := fs.Bind(c.Challenges[i], curve.MarshalG1(c.Points[j][i])); err != nil { + if err := fs.Bind(c.Challenges[i], curve.MarshalG1(c.Points[i][j])); err != nil { return fmt.Errorf("bind[%d][%d] %w", i, j, err) } } @@ -295,5 +296,77 @@ func (c *transcriptCircuit[S, G1El]) Define(api frontend.API) error { } func TestTranscriptMarsha(t *testing.T) { - + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + h, err := recursion.NewShort(ecc.BW6_761.ScalarField(), ecc.BLS12_377.ScalarField()) + assert.NoError(err) + challenges := [3]string{"alfa", "beta", "gamma"} + fs := cryptofs.NewTranscript(h, challenges[:]...) + var points [3][3]sw_bls12377.G1Affine + for i := range points { + for j := range points[i] { + var p bls12377.G1Affine + r, err := rand.Int(rand.Reader, ecc.BLS12_377.ScalarField()) + assert.NoError(err) + p.ScalarMultiplicationBase(r) + points[i][j] = sw_bls12377.NewG1Affine(p) + if err := fs.Bind(challenges[i], p.Marshal()); err != nil { + t.Fatal("bind", err) + } + } + } + var expected [3]frontend.Variable + for i := range expected { + res, err := fs.ComputeChallenge(challenges[i]) + assert.NoError(err) + expected[i] = res + } + circuit := &transcriptCircuit[sw_bls12377.Scalar, sw_bls12377.G1Affine]{ + Challenges: challenges, + target: ecc.BLS12_377.ScalarField(), + } + assignment := &transcriptCircuit[sw_bls12377.Scalar, sw_bls12377.G1Affine]{ + Challenges: challenges, + Points: points, + Expected: expected, + target: ecc.BLS12_377.ScalarField(), + } + assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BW6_761), test.NoFuzzing(), test.NoSerializationChecks()) + }, "bw6_761") + assert.Run(func(assert *test.Assert) { + h, err := recursion.NewShort(ecc.BN254.ScalarField(), ecc.BW6_761.ScalarField()) + assert.NoError(err) + challenges := [3]string{"alfa", "beta", "gamma"} + fs := cryptofs.NewTranscript(h, challenges[:]...) + var points [3][3]sw_bw6761.G1Affine + for i := range points { + for j := range points[i] { + var p bw6761.G1Affine + r, err := rand.Int(rand.Reader, ecc.BW6_761.ScalarField()) + assert.NoError(err) + p.ScalarMultiplicationBase(r) + points[i][j] = sw_bw6761.NewG1Affine(p) + if err := fs.Bind(challenges[i], p.Marshal()); err != nil { + t.Fatal("bind", err) + } + } + } + var expected [3]frontend.Variable + for i := range expected { + res, err := fs.ComputeChallenge(challenges[i]) + assert.NoError(err) + expected[i] = res + } + circuit := &transcriptCircuit[sw_bw6761.Scalar, sw_bw6761.G1Affine]{ + Challenges: challenges, + target: ecc.BW6_761.ScalarField(), + } + assignment := &transcriptCircuit[sw_bw6761.Scalar, sw_bw6761.G1Affine]{ + Challenges: challenges, + Points: points, + Expected: expected, + target: ecc.BW6_761.ScalarField(), + } + assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254), test.NoFuzzing(), test.NoSerializationChecks()) + }, "bn254") }